diff --git a/Makefile b/Makefile
index d75e10abe..b797345f5 100644
--- a/Makefile
+++ b/Makefile
@@ -6,11 +6,7 @@ ifndef PY
 PY = 3
 endif
 
-FORMAT_ENFORCE_DIRS = state/
-FORMAT_EXCLUDE_REGEX = '.*'
-FORMAT_EXCLUDE_GLOB = '*'
-FORMAT_LINE_LENGTH = 80
-
+LINT_ENFORCE_DIRS = ./bin ./mig/lib ./sbin ./tests
 LOCAL_PYTHON_BIN = './envhelp/lpython'
 
 ifdef PYTHON_BIN
@@ -42,15 +38,10 @@ ifneq ($(MIG_ENV),'local')
 endif
	@make format-python
 
-.PHONY:format-python
+.PHONY: format-python
 format-python:
-	@$(LOCAL_PYTHON_BIN) -m black $(FORMAT_ENFORCE_DIRS) \
-		--line-length=$(FORMAT_LINE_LENGTH) \
-		--exclude=$(FORMAT_EXCLUDE_REGEX)
-	@$(LOCAL_PYTHON_BIN) -m isort $(FORMAT_ENFORCE_DIRS) \
-		--profile=black \
-		--line-length=$(FORMAT_LINE_LENGTH) \
-		--skip-glob=$(FORMAT_EXCLUDE_GLOB)
+	@$(LOCAL_PYTHON_BIN) -m black $(LINT_ENFORCE_DIRS)
+	@$(LOCAL_PYTHON_BIN) -m isort $(LINT_ENFORCE_DIRS)
 
 .PHONY: lint
 lint:
@@ -62,15 +53,8 @@ endif
 
 .PHONY: lint-python
 lint-python:
-	@$(LOCAL_PYTHON_BIN) -m black $(FORMAT_ENFORCE_DIRS) \
-		--check \
-		--line-length=$(FORMAT_LINE_LENGTH) \
-		--exclude $(FORMAT_EXCLUDE_REGEX)
-	@$(LOCAL_PYTHON_BIN) -m isort $(FORMAT_ENFORCE_DIRS) \
-		--check-only \
-		--profile=black \
-		--line-length=$(FORMAT_LINE_LENGTH) \
-		--skip-glob=$(FORMAT_EXCLUDE_GLOB)
+	@$(LOCAL_PYTHON_BIN) -m black $(LINT_ENFORCE_DIRS) --check
+	@$(LOCAL_PYTHON_BIN) -m isort $(LINT_ENFORCE_DIRS) --check-only
 
 .PHONY: clean
 clean:
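With the FORMAT_* knobs gone, both formatters now run with their built-in defaults (notably black's 88-character line length) over the directories in LINT_ENFORCE_DIRS. Assuming PYTHON_BIN is unset, so the ./envhelp/lpython wrapper is selected, the format-python recipe expands to roughly:

    ./envhelp/lpython -m black ./bin ./mig/lib ./sbin ./tests
    ./envhelp/lpython -m isort ./bin ./mig/lib ./sbin ./tests

and lint-python runs the same two commands with --check and --check-only appended, so CI reports violations instead of rewriting files.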
diff --git a/bin/addheader.py b/bin/addheader.py
index bb6efdcdd..0808c3ccb 100755
--- a/bin/addheader.py
+++ b/bin/addheader.py
@@ -159,14 +159,11 @@ def add_header(path, var_dict, explicit_border=True, block_wrap=False):
         END_MARKER,
     )
     if block_wrap:
-        lic = (
-            """
+        lic = """
 /*
 %s
 */
-"""
-            % lic
-        )
+""" % lic
     module_header.append(lic)
 
     # Make sure there's a blank line between license header and code
diff --git a/bin/showaccounting.py b/bin/showaccounting.py
index 1c5af61ef..e1b13e9fd 100755
--- a/bin/showaccounting.py
+++ b/bin/showaccounting.py
@@ -39,7 +39,7 @@
 from mig.shared.defaults import gdp_distinguished_field
 
 
-def usage(name='showaccounting.py'):
+def usage(name="showaccounting.py"):
     """Usage help"""
     print("""Create accounting information based on quota.
@@ -53,33 +53,31 @@ def usage(name='showaccounting.py'):
 -m MIN_USAGE    Only show accounts using more than minimum usage (TB).
 -t TIMESTAMP    Use specific timestamp, latest if unset
-""" % {'name': name})
+""" % {"name": name})
 
 
-def show_accounting(configuration,
-                    timestamp,
-                    user_filter,
-                    minimum_usage,
-                    verbose):
+def show_accounting(
+    configuration, timestamp, user_filter, minimum_usage, verbose
+):
     """Print user accounting report"""
     user_filter_re = None
     if user_filter:
         try:
             user_filter_re = re.compile(user_filter)
         except Exception as err:
-            print("ERROR: Failed to compile user_filter: %r error: %s"
-                  % (user_filter, err))
+            print(
+                "ERROR: Failed to compile user_filter: %r error: %s"
+                % (user_filter, err)
+            )
             return
-    usage = get_usage(configuration,
-                      timestamp=timestamp,
-                      verbose=verbose)
+    usage = get_usage(configuration, timestamp=timestamp, verbose=verbose)
 
-    accounting = usage.get('accounting', {})
-    accounting_timestamp = usage.get('timestamp', 0)
-    accounting_datestr \
-        = datetime.datetime.fromtimestamp(accounting_timestamp) \
-        .strftime('%d/%m/%Y-%H:%M:%S')
+    accounting = usage.get("accounting", {})
+    accounting_timestamp = usage.get("timestamp", 0)
+    accounting_datestr = datetime.datetime.fromtimestamp(
+        accounting_timestamp
+    ).strftime("%d/%m/%Y-%H:%M:%S")
 
     # Sorted by total bytes and print usage for users
 
@@ -91,14 +89,19 @@ def show_accounting(configuration,
     for username, values in accounting.items():
         # Do not show GDP project users
         # projects are accounted for by the main user
-        if configuration.site_enable_gdp \
-                and username.find("/%s=" % gdp_distinguished_field) != -1:
+        if (
+            configuration.site_enable_gdp
+            and username.find("/%s=" % gdp_distinguished_field) != -1
+        ):
             continue
         report_total_users += 1
-        total_bytes = values.get('total_bytes', 0)
+        total_bytes = values.get("total_bytes", 0)
         report_total_bytes += total_bytes
-        if total_bytes < minimum_usage \
-                or user_filter_re and not user_filter_re.fullmatch(username):
+        if (
+            total_bytes < minimum_usage
+            or user_filter_re
+            and not user_filter_re.fullmatch(username)
+        ):
             continue
         report_shown_users += 1
         report_shown_bytes += total_bytes
@@ -107,33 +110,36 @@ def show_accounting(configuration,
             total_bytes_map[total_bytes] = total_bytes_map_userlist
     sorted_total_bytes = sorted(list(total_bytes_map), reverse=True)
 
-    print("\nAccounting (%d) %s for storage quota(s):"
-          % (accounting_timestamp, accounting_datestr))
-    for quota_fs, values in usage.get('quota', {}).items():
-        quota_mtime = values.get('mtime', 0)
-        quota_datestr = datetime.datetime.fromtimestamp(quota_mtime) \
-            .strftime('%d/%m/%Y-%H:%M:%S')
-        print(" - %s (%d) %s" % (quota_fs,
-                                 quota_mtime,
-                                 quota_datestr))
-
-    print("Found a total of %s users using %s storage"
-          % (report_total_users,
-             human_readable_filesize(report_total_bytes)))
-    print("Showing details for %s users using %s storage "
-          % (report_shown_users,
-             human_readable_filesize(report_shown_bytes)))
+    print(
+        "\nAccounting (%d) %s for storage quota(s):"
+        % (accounting_timestamp, accounting_datestr)
+    )
+    for quota_fs, values in usage.get("quota", {}).items():
+        quota_mtime = values.get("mtime", 0)
+        quota_datestr = datetime.datetime.fromtimestamp(quota_mtime).strftime(
+            "%d/%m/%Y-%H:%M:%S"
+        )
+        print(" - %s (%d) %s" % (quota_fs, quota_mtime, quota_datestr))
+
+    print(
+        "Found a total of %s users using %s storage"
+        % (report_total_users, human_readable_filesize(report_total_bytes))
+    )
+    print(
+        "Showing details for %s users using %s storage "
+        % (report_shown_users, human_readable_filesize(report_shown_bytes))
+    )
     print("User filter: %r" % user_filter)
     print("Minimum usage: %s" % human_readable_filesize(minimum_usage))
     for total_bytes in sorted_total_bytes:
         total_bytes_human = human_readable_filesize(total_bytes)
         for username in total_bytes_map[total_bytes]:
             report = accounting[username]
-            home_report = report.get('home_report', '')
-            freeze_report = report.get('freeze_report', '')
-            vgrid_report = report.get('vgrid_report', '')
-            ext_users_report = report.get('ext_users_report', '')
-            peers_report = report.get('peers_report', '')
+            home_report = report.get("home_report", "")
+            freeze_report = report.get("freeze_report", "")
+            vgrid_report = report.get("vgrid_report", "")
+            ext_users_report = report.get("ext_users_report", "")
+            peers_report = report.get("peers_report", "")
             print("\n%s:" % username)
             print("Total usage: %s" % total_bytes_human)
             if home_report:
@@ -148,46 +154,44 @@ def show_accounting(configuration,
                 print(peers_report)
 
 
-if '__main__' == __name__:
+if "__main__" == __name__:
     conf_path = None
     user_filter = None
     timestamp = 0
     minimum_usage = 0
     verbose = False
-    opt_args = 'hvc:f:m:t:'
+    opt_args = "hvc:f:m:t:"
     try:
-        (opts, args) = getopt.getopt(sys.argv[1:], opt_args)
-        for (opt, val) in opts:
-            if opt == '-h':
+        opts, args = getopt.getopt(sys.argv[1:], opt_args)
+        for opt, val in opts:
+            if opt == "-h":
                 usage()
                 sys.exit(0)
-            if opt == '-v':
+            if opt == "-v":
                 verbose = True
-            elif opt == '-c':
+            elif opt == "-c":
                 conf_path = val
-            elif opt == '-f':
+            elif opt == "-f":
                 user_filter = val
-            elif opt == '-m':
-                minimum_usage = float(val)*(1024**4)
-            elif opt == '-t':
+            elif opt == "-m":
+                minimum_usage = float(val) * (1024**4)
+            elif opt == "-t":
                 timestamp = int(val)
             else:
-                print('Error: %s not supported!' % opt)
+                print("Error: %s not supported!" % opt)
                 usage()
                 sys.exit(1)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)
 
-    configuration = get_configuration_object(config_file=conf_path,
-                                             skip_log=True,
-                                             disable_auth_log=True)
+    configuration = get_configuration_object(
+        config_file=conf_path, skip_log=True, disable_auth_log=True
+    )
 
-    show_accounting(configuration,
-                    timestamp,
-                    user_filter,
-                    minimum_usage,
-                    verbose)
+    show_accounting(
+        configuration, timestamp, user_filter, minimum_usage, verbose
+    )
 
     sys.exit(0)
diff --git a/bin/verifyvgridformat.py b/bin/verifyvgridformat.py
index c1ce84aae..2289c4b01 100755
--- a/bin/verifyvgridformat.py
+++ b/bin/verifyvgridformat.py
@@ -49,8 +49,7 @@ def usage(name=sys.argv[0]):
 -v|--verbose            Verbose output
 -c PATH|--config=PATH   Path to config file
 -n NAME|--name=NAME     Only verify specific vgrid
-"""
-        % {"name": name},
+""" % {"name": name},
         file=sys.stderr,
     )
 
@@ -78,7 +77,7 @@ def verify_vgrid_format(configuration, vgrid_name=None, verbose=False):
             vgrid_dirpath = os.path.join(root, dirent)
             owners_filepath = os.path.join(vgrid_dirpath, "owners")
             if os.path.isfile(owners_filepath):
-                vgrid = vgrid_dirpath[len(configuration.vgrid_home):].strip(
+                vgrid = vgrid_dirpath[len(configuration.vgrid_home) :].strip(
                     os.sep
                 )
                 vgrid_mapping[vgrid] = vgrid_flat_name(vgrid, configuration)
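The only code change in verifyvgridformat.py is black's PEP 8 slice spacing: a space is added before the colon because the lower bound is a compound expression; the computed vgrid name is identical. A quick illustration with hypothetical paths (the vgrid_home and dirpath values below are made up):

    # hypothetical values, for illustration only
    vgrid_home = "/home/mig/state/vgrid_home/"
    vgrid_dirpath = "/home/mig/state/vgrid_home/eScience/sub"
    vgrid = vgrid_dirpath[len(vgrid_home) :].strip("/")
    print(vgrid)  # -> 'eScience/sub'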
entry""" if vgrid_bytes is None: vgrid_bytes = {} @@ -54,11 +52,13 @@ def __init_accounting_entry(user_bytes=0, if ext_users is None: ext_users = {} - return {'user_bytes': user_bytes, - 'freeze_bytes': freeze_bytes, - 'vgrid_bytes': vgrid_bytes, - 'peers': peers, - 'ext_users': ext_users} + return { + "user_bytes": user_bytes, + "freeze_bytes": freeze_bytes, + "vgrid_bytes": vgrid_bytes, + "peers": peers, + "ext_users": ext_users, + } def __get_owned_vgrid(configuration, verbose=False): @@ -67,17 +67,16 @@ def __get_owned_vgrid(configuration, verbose=False): NOTE: First owner of top-vgrid is primary owner""" logger = configuration.logger result = {} - (status, vgrids) = vgrid_list_vgrids(configuration) + status, vgrids = vgrid_list_vgrids(configuration) if status: for vgrid_name in vgrids: # print("checking vgrid: %s" % check_vgrid_name) - (owners_status, owners_list) = vgrid_list(vgrid_name, - 'owners', - configuration, - recursive=True) + owners_status, owners_list = vgrid_list( + vgrid_name, "owners", configuration, recursive=True + ) # Find first non-zero owner # NOTE: Some owner files contain empty owners) - owner = '' + owner = "" if owners_status and owners_list: owner = next(ent for ent in owners_list if ent) if owner: @@ -85,8 +84,7 @@ def __get_owned_vgrid(configuration, verbose=False): owned_vgrids.append(vgrid_name) result[owner] = owned_vgrids else: - msg = "Failed to find owner for vgrid: %s" \ - % vgrid_name + msg = "Failed to find owner for vgrid: %s" % vgrid_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -108,36 +106,37 @@ def __get_peers_map(configuration, verbose=False): accepted_peers = get_accepted_peers(configuration, client_id) for ext_client_id, value in accepted_peers.items(): if not isinstance(value, dict): - msg = "Invalid peers format: %s: %s: %s" \ - % (client_id, ext_client_id, value) + msg = "Invalid peers format: %s: %s: %s" % ( + client_id, + ext_client_id, + value, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue # Map external users to their peer - ext_users = peer_result.get('ext_users', {}) + ext_users = peer_result.get("ext_users", {}) ext_users[ext_client_id] = value - peer_result['ext_users'] = ext_users + peer_result["ext_users"] = ext_users # Map peers to their external user ext_result = result.get(ext_client_id, {}) - peers = ext_result.get('peers', {}) + peers = ext_result.get("peers", {}) peers[client_id] = value - ext_result['peers'] = peers + ext_result["peers"] = peers result[ext_client_id] = ext_result result[client_id] = peer_result return result -def update_accounting(configuration, - verbose=False): +def update_accounting(configuration, verbose=False): """Update user accounting information""" logger = configuration.logger retval = True - result = {'accounting': {}, - 'quota': {}} - accounting = result['accounting'] - result['timestamp'] = int(time.time()) + result = {"accounting": {}, "quota": {}} + accounting = result["accounting"] + result["timestamp"] = int(time.time()) # Map vgrid to their primary owner msg = "Creating vgrid owners map ..." 
@@ -184,27 +183,24 @@ def update_accounting(configuration, quota_info_json = entry.path quota_fs = entry.name.replace(".json", "") else: - logger.debug("Skipping non quota info entry: %s" - % entry.name) + logger.debug("Skipping non quota info entry: %s" % entry.name) continue quota_info = None # Try .pck first then .json if quota_info_pck: quota_info = unpickle(quota_info_pck, configuration.logger) elif quota_info_json: - quota_info = load_json(quota_info_json, - configuration.logger, - convert_utf8=False) + quota_info = load_json( + quota_info_json, configuration.logger, convert_utf8=False + ) if not quota_info: - msg = "Failed to load quota info for FS entry: %s" \ - % entry.name + msg = "Failed to load quota info for FS entry: %s" % entry.name logger.error(msg) if verbose: print("ERROR: %s" % msg) retval = False continue - quota_basepath = os.path.join(configuration.quota_home, - quota_fs) + quota_basepath = os.path.join(configuration.quota_home, quota_fs) if not os.path.isdir(quota_basepath): msg = "Missing quota_basepath: %r" % quota_basepath logger.error(msg) @@ -212,14 +208,15 @@ def update_accounting(configuration, print("ERROR: %s" % msg) retval = False continue - quota_mtime = quota_info.get('mtime', 0) - quota_datestr = datetime.datetime.fromtimestamp(quota_mtime) \ - .strftime('%d/%m/%Y-%H:%M:%S') - result['quota'][quota_fs] = {'mtime': quota_mtime} + quota_mtime = quota_info.get("mtime", 0) + quota_datestr = datetime.datetime.fromtimestamp( + quota_mtime + ).strftime("%d/%m/%Y-%H:%M:%S") + result["quota"][quota_fs] = {"mtime": quota_mtime} # User quota - user_path = os.path.join(quota_basepath, 'user') + user_path = os.path.join(quota_basepath, "user") if not os.path.isdir(user_path): msg = "Missing quota user path: %r" % user_path logger.error(msg) @@ -228,11 +225,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s user quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - user_path) + msg = "Scanning %s user quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + user_path, + ) logger.info(msg) if verbose: print(msg) @@ -241,30 +239,34 @@ def update_accounting(configuration, for user_entry in it2: if user_entry.name.endswith(".pck"): client_id = client_dir_id( - user_entry.name.replace('.pck', '')) + user_entry.name.replace(".pck", "") + ) elif user_entry.name.endswith(".json"): client_id = client_dir_id( - user_entry.name.replace('.json', '')) + user_entry.name.replace(".json", "") + ) else: - logger.debug("Skipping non-user entry: %s" - % user_entry.name) + logger.debug( + "Skipping non-user entry: %s" % user_entry.name + ) continue user_quota_files[client_id] = user_entry.path t2 = time.time() - msg = "Scanned %s user quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - user_path, - (t2 - t1)) + msg = "Scanned %s user quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + user_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) # Vgrid quota - vgrid_path = os.path.join(quota_basepath, 'vgrid') + vgrid_path = os.path.join(quota_basepath, "vgrid") if not os.path.isdir(vgrid_path): msg = "Missing quota vgrid path: %r" % vgrid_path logger.error(msg) @@ -273,11 +275,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s vgrid quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - vgrid_path) + msg = "Scanning %s vgrid quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + vgrid_path, + ) 
logger.info(msg) if verbose: print(msg) @@ -286,26 +289,29 @@ def update_accounting(configuration, for vgrid_entry in it2: if vgrid_entry.name.endswith(".pck"): vgrid_name = force_native_str( - vgrid_entry.name.replace('.pck', '')) + vgrid_entry.name.replace(".pck", "") + ) elif vgrid_entry.name.endswith(".json"): vgrid_name = force_native_str( - vgrid_entry.name.replace('.json', '')) + vgrid_entry.name.replace(".json", "") + ) else: # logger.debug("Skipping non-vgrid entry: %s" # % vgrid_entry.name) continue # NOTE: sub-vgrids uses ':' # as delimiter in 'vgrid_files_writable' - vgrid_name = vgrid_name.replace(':', '/') + vgrid_name = vgrid_name.replace(":", "/") # print("%s: %s" % (vgrid_name, vgrid_entry.path)) vgrid_quota_files[vgrid_name] = vgrid_entry.path t2 = time.time() - msg = "Scanned %s vgrid quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - vgrid_path, - (t2 - t1)) + msg = "Scanned %s vgrid quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + vgrid_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) @@ -313,7 +319,7 @@ def update_accounting(configuration, # Freeze quota if configuration.site_enable_freeze: - freeze_path = os.path.join(quota_basepath, 'freeze') + freeze_path = os.path.join(quota_basepath, "freeze") if not os.path.isdir(freeze_path): msg = "Missing quota freeze path: %r" % freeze_path logger.error(msg) @@ -322,11 +328,12 @@ def update_accounting(configuration, retval = False continue - msg = "Scanning %s freeze quota (%d) %s %r" \ - % (quota_fs, - quota_mtime, - quota_datestr, - freeze_path) + msg = "Scanning %s freeze quota (%d) %s %r" % ( + quota_fs, + quota_mtime, + quota_datestr, + freeze_path, + ) logger.info(msg) if verbose: print(msg) @@ -335,23 +342,27 @@ def update_accounting(configuration, for freeze_entry in it2: if freeze_entry.name.endswith(".pck"): freeze_client_id = client_dir_id( - freeze_entry.name.replace('.pck', '')) + freeze_entry.name.replace(".pck", "") + ) elif freeze_entry.name.endswith(".json"): freeze_client_id = client_dir_id( - freeze_entry.name.replace('.json', '')) + freeze_entry.name.replace(".json", "") + ) else: - logger.debug("Skipping non-freeze entry: %s" - % freeze_entry.name) + logger.debug( + "Skipping non-freeze entry: %s" + % freeze_entry.name + ) continue - freeze_quota_files[freeze_client_id] \ - = freeze_entry.path + freeze_quota_files[freeze_client_id] = freeze_entry.path t2 = time.time() - msg = "Scanned %s freeze quota (%d) %s %r in %d secs" \ - % (quota_fs, - quota_mtime, - quota_datestr, - freeze_path, - (t2 - t1)) + msg = "Scanned %s freeze quota (%d) %s %r in %d secs" % ( + quota_fs, + quota_mtime, + quota_datestr, + freeze_path, + (t2 - t1), + ) logger.info(msg) if verbose: print(msg) @@ -361,17 +372,18 @@ def update_accounting(configuration, vgrids_accounted = [] for client_id, user_quota_filepath in user_quota_files.items(): # Init user accounting - peers = peers_map.get(client_id, {}).get('peers', {}) - ext_users = peers_map.get(client_id, {}).get('ext_users', {}) - accounting[client_id] = __init_accounting_entry(peers=peers, - ext_users=ext_users) + peers = peers_map.get(client_id, {}).get("peers", {}) + ext_users = peers_map.get(client_id, {}).get("ext_users", {}) + accounting[client_id] = __init_accounting_entry( + peers=peers, ext_users=ext_users + ) # Extract user bytes - if user_quota_filepath.endswith('.pck'): + if user_quota_filepath.endswith(".pck"): user_quota = unpickle(user_quota_filepath, configuration) - elif 
user_quota_filepath.endswith('.json'): - user_quota = load_json(user_quota_filepath, - configuration.logger, - convert_utf8=False) + elif user_quota_filepath.endswith(".json"): + user_quota = load_json( + user_quota_filepath, configuration.logger, convert_utf8=False + ) else: msg = "Invalid user quota file: %r" % user_quota_filepath logger.error(msg) @@ -380,11 +392,13 @@ def update_accounting(configuration, retval = False continue try: - accounting[client_id]['user_bytes'] = user_quota['bytes'] + accounting[client_id]["user_bytes"] = user_quota["bytes"] except Exception as err: - accounting[client_id]['user_bytes'] = 0 - msg = "Failed to load user quota: %r, error: %s" \ - % (user_quota_filepath, err) + accounting[client_id]["user_bytes"] = 0 + msg = "Failed to load user quota: %r, error: %s" % ( + user_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -394,27 +408,28 @@ def update_accounting(configuration, # Extract vgrid bytes for user 'client_id' for vgrid_name in owned_vgrid.get(client_id, []): - vgrid_quota_filepath = vgrid_quota_files.get(vgrid_name, '') + vgrid_quota_filepath = vgrid_quota_files.get(vgrid_name, "") if not os.path.exists(vgrid_quota_filepath): if verbose: # NOTE: Legacy vgrids are accounted at by top-vgrid - vgrid_array = vgrid_name.split('/') - legacy_vgrid = os.path.join(configuration.vgrid_files_home, - vgrid_name) - if not os.path.isdir(legacy_vgrid) \ - or len(vgrid_array) == 1: - msg = "Missing quota for vgrid: %r" \ - % vgrid_name + vgrid_array = vgrid_name.split("/") + legacy_vgrid = os.path.join( + configuration.vgrid_files_home, vgrid_name + ) + if not os.path.isdir(legacy_vgrid) or len(vgrid_array) == 1: + msg = "Missing quota for vgrid: %r" % vgrid_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue - if vgrid_quota_filepath.endswith('.pck'): + if vgrid_quota_filepath.endswith(".pck"): vgrid_quota = unpickle(vgrid_quota_filepath, configuration) - elif vgrid_quota_filepath.endswith('.json'): - vgrid_quota = load_json(vgrid_quota_filepath, - configuration.logger, - convert_utf8=False) + elif vgrid_quota_filepath.endswith(".json"): + vgrid_quota = load_json( + vgrid_quota_filepath, + configuration.logger, + convert_utf8=False, + ) else: msg = "Invalid vgrid quota file: %r" % vgrid_quota_filepath logger.error(msg) @@ -423,12 +438,15 @@ def update_accounting(configuration, retval = False continue try: - accounting[client_id]['vgrid_bytes'][vgrid_name] \ - = vgrid_quota['bytes'] + accounting[client_id]["vgrid_bytes"][vgrid_name] = vgrid_quota[ + "bytes" + ] except Exception as err: - accounting[client_id]['vgrid_bytes'][vgrid_name] = 0 - msg = "Failed to load vgrid quota: %r, error: %s" \ - % (vgrid_quota_filepath, err) + accounting[client_id]["vgrid_bytes"][vgrid_name] = 0 + msg = "Failed to load vgrid quota: %r, error: %s" % ( + vgrid_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -442,13 +460,15 @@ def update_accounting(configuration, for vgrid_name in vgrid_quota_files: if vgrid_name not in vgrids_accounted: - vgridowner = '' + vgridowner = "" for owner, owned_vgrids in owned_vgrid.items(): if vgrid_name in owned_vgrids: vgridowner = owner break - msg = "no accounting for vgrid: %r, missing owner?: %r" \ - % (vgrid_name, vgridowner) + msg = "no accounting for vgrid: %r, missing owner?: %r" % ( + vgrid_name, + vgridowner, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -457,24 +477,25 @@ def update_accounting(configuration, for freeze_name, 
freeze_quota_filepath in freeze_quota_files.items(): # Extract client_id from legacy freeze archive format - if freeze_name.startswith('archive-'): - legacy_freeze_meta_filepath \ - = os.path.join(configuration.freeze_home, - freeze_name, - 'meta.pck') - legacy_freeze_meta = unpickle(legacy_freeze_meta_filepath, - configuration.logger) + if freeze_name.startswith("archive-"): + legacy_freeze_meta_filepath = os.path.join( + configuration.freeze_home, freeze_name, "meta.pck" + ) + legacy_freeze_meta = unpickle( + legacy_freeze_meta_filepath, configuration.logger + ) if not legacy_freeze_meta: - msg = "Missing metadata for archive: %r" \ - % freeze_name + msg = "Missing metadata for archive: %r" % freeze_name logger.warning(msg) if verbose: print("WARNING: %s" % msg) continue - client_id = legacy_freeze_meta.get('CREATOR', '') + client_id = legacy_freeze_meta.get("CREATOR", "") if not client_id: - msg = "Failed to extract client_id from: %r" \ - % legacy_freeze_meta_filepath + msg = ( + "Failed to extract client_id from: %r" + % legacy_freeze_meta_filepath + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -486,12 +507,12 @@ def update_accounting(configuration, # Load freeze quota freeze_bytes = 0 - if freeze_quota_filepath.endswith('.pck'): + if freeze_quota_filepath.endswith(".pck"): freeze_quota = unpickle(freeze_quota_filepath, configuration) - elif freeze_quota_filepath.endswith('.json'): - freeze_quota = load_json(freeze_quota_filepath, - configuration.logger, - convert_utf8=False) + elif freeze_quota_filepath.endswith(".json"): + freeze_quota = load_json( + freeze_quota_filepath, configuration.logger, convert_utf8=False + ) else: msg = "Invalid freeze quota file: %r" % freeze_quota_filepath logger.error(msg) @@ -500,11 +521,13 @@ def update_accounting(configuration, retval = False continue try: - freeze_bytes = int(freeze_quota['bytes']) + freeze_bytes = int(freeze_quota["bytes"]) except Exception as err: freeze_bytes = 0 - msg = "Failed to fetch freeze quota: %r, error: %s" \ - % (freeze_quota_filepath, err) + msg = "Failed to fetch freeze quota: %r, error: %s" % ( + freeze_quota_filepath, + err, + ) logger.error(msg) if verbose: print("ERROR: %s" % msg) @@ -512,24 +535,27 @@ def update_accounting(configuration, continue if freeze_bytes > 0: - freeze_accounting = accounting.get(client_id, '') + freeze_accounting = accounting.get(client_id, "") if not freeze_accounting: - msg = "added missing archive user: %r : %d" \ - % (client_id, freeze_bytes) + msg = "added missing archive user: %r : %d" % ( + client_id, + freeze_bytes, + ) logger.warning(msg) if verbose: print("WARNING: %s" % msg) accounting[client_id] = __init_accounting_entry() freeze_accounting = accounting[client_id] - freeze_accounting['freeze_bytes'] += freeze_bytes + freeze_accounting["freeze_bytes"] += freeze_bytes # Save accounting result - accounting_filepath = os.path.join(configuration.accounting_home, - "%s.pck" % result['timestamp']) + accounting_filepath = os.path.join( + configuration.accounting_home, "%s.pck" % result["timestamp"] + ) status = pickle(result, accounting_filepath, configuration.logger) if status: - latest = os.path.join(configuration.accounting_home, 'latest') + latest = os.path.join(configuration.accounting_home, "latest") status = make_symlink(accounting_filepath, latest, logger, force=True) if not status: retval = False @@ -546,33 +572,26 @@ def human_readable_filesize(filesize): return "0 B" try: p = int(math.floor(math.log(filesize, 2) / 10)) - return "%.3f %s" % (filesize / 
math.pow(1024, p), - ['B', - 'KiB', - 'MiB', - 'GiB', - 'TiB', - 'PiB', - 'EiB', - 'ZiB', - 'YiB'][p]) + return "%.3f %s" % ( + filesize / math.pow(1024, p), + ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"][p], + ) except (ValueError, TypeError, IndexError): - return 'NaN' + return "NaN" -def get_usage(configuration, - userlist=[], - timestamp=0, - verbose=False): +def get_usage(configuration, userlist=[], timestamp=0, verbose=False): """Generate and return 'storage' usage""" # Load accounting if it exists logger = configuration.logger if timestamp == 0: - accounting_filepath = os.path.join(configuration.accounting_home, - "latest") + accounting_filepath = os.path.join( + configuration.accounting_home, "latest" + ) else: - accounting_filepath = os.path.join(configuration.accounting_home, - "%s.pck" % timestamp) + accounting_filepath = os.path.join( + configuration.accounting_home, "%s.pck" % timestamp + ) data = unpickle(accounting_filepath, configuration.logger) if not data: msg = "Failed to load accounting data from: %r" % accounting_filepath @@ -581,7 +600,7 @@ def get_usage(configuration, print("ERROR: %s" % msg) return None - accounting = data.get('accounting', {}) + accounting = data.get("accounting", {}) # Do not show external users as main accounts unless requested # or if the user act as both peer and external user @@ -590,10 +609,13 @@ def get_usage(configuration, peer_users = [] skip_ext_users = [] for values in accounting.values(): - ext_users.extend(list(values.get('ext_users', {}))) - peer_users.extend(list(values.get('peers', {}))) - skip_ext_users = [user for user in ext_users - if user not in userlist and user not in peer_users] + ext_users.extend(list(values.get("ext_users", {}))) + peer_users.extend(list(values.get("peers", {}))) + skip_ext_users = [ + user + for user in ext_users + if user not in userlist and user not in peer_users + ] # Create accounting report @@ -610,7 +632,7 @@ def get_usage(configuration, # Home usage - home_bytes = values.get('user_bytes', 0) + home_bytes = values.get("user_bytes", 0) total_bytes += home_bytes home_report = "" if create_reports: @@ -619,7 +641,7 @@ def get_usage(configuration, # Freeze archive usage - freeze_bytes = values.get('freeze_bytes', 0) + freeze_bytes = values.get("freeze_bytes", 0) total_bytes += freeze_bytes freeze_report = "" if create_reports and freeze_bytes > 0: @@ -630,32 +652,34 @@ def get_usage(configuration, vgrid_report = "" vgrid_total = 0 - for vgrid_name, vgrid_bytes in values.get('vgrid_bytes', {}).items(): + for vgrid_name, vgrid_bytes in values.get("vgrid_bytes", {}).items(): vgrid_total += vgrid_bytes if create_reports: vgrid_bytes_human = human_readable_filesize(vgrid_bytes) - vgrid_report += "\n - %s: %s" \ - % (vgrid_name, vgrid_bytes_human) + vgrid_report += "\n - %s: %s" % (vgrid_name, vgrid_bytes_human) if vgrid_report: - vgrid_report = "%s usage (total: %s)%s" \ - % (configuration.site_vgrid_label, - human_readable_filesize(vgrid_total), - vgrid_report) + vgrid_report = "%s usage (total: %s)%s" % ( + configuration.site_vgrid_label, + human_readable_filesize(vgrid_total), + vgrid_report, + ) total_bytes += vgrid_total # Create account usage entry - account_usage[username] = {'total_bytes': total_bytes, - 'home_total': home_bytes, - 'vgrid_total': vgrid_total, - 'freeze_total': freeze_bytes, - 'ext_users_total': 0, - 'total_report': '', - 'home_report': home_report, - 'freeze_report': freeze_report, - 'vgrid_report': vgrid_report, - 'ext_users_report': '', - 'peers_report': ''} + 
account_usage[username] = { + "total_bytes": total_bytes, + "home_total": home_bytes, + "vgrid_total": vgrid_total, + "freeze_total": freeze_bytes, + "ext_users_total": 0, + "total_report": "", + "home_report": home_report, + "freeze_report": freeze_report, + "vgrid_report": vgrid_report, + "ext_users_report": "", + "peers_report": "", + } # Create external users report # NOTE: We need total bytes and therefore we need the above full report @@ -667,13 +691,12 @@ def get_usage(configuration, if userlist and username not in userlist: continue # Create ext_users report - ext_users = values.get('ext_users', {}) - peers = values.get('peers', {}) + ext_users = values.get("ext_users", {}) + peers = values.get("peers", {}) if not ext_users: continue if ext_users and peers: - msg = "User %r acts as both peer and external user" \ - % username + msg = "User %r acts as both peer and external user" % username logger.warning(msg) if verbose: print("WARNING: %s" % msg) @@ -683,20 +706,25 @@ def get_usage(configuration, ext_users_report = "" ext_users_total = 0 for ext_user in ext_users: - ext_user_total_bytes = account_usage.get( - ext_user, {}).get('total_bytes', 0) + ext_user_total_bytes = account_usage.get(ext_user, {}).get( + "total_bytes", 0 + ) ext_users_total += ext_user_total_bytes ext_user_total_bytes_human = human_readable_filesize( - ext_user_total_bytes) - ext_users_report += "\n - %s: %s" % (ext_user, - ext_user_total_bytes_human) + ext_user_total_bytes + ) + ext_users_report += "\n - %s: %s" % ( + ext_user, + ext_user_total_bytes_human, + ) if ext_users_report: - ext_users_report = "External users usage (total: %s):%s" \ - % (human_readable_filesize(ext_users_total), - ext_users_report) - account_usage[username]['ext_users_total'] = ext_users_total - account_usage[username]['ext_users_report'] = ext_users_report - account_usage[username]['total_bytes'] += ext_users_total + ext_users_report = "External users usage (total: %s):%s" % ( + human_readable_filesize(ext_users_total), + ext_users_report, + ) + account_usage[username]["ext_users_total"] = ext_users_total + account_usage[username]["ext_users_report"] = ext_users_report + account_usage[username]["total_bytes"] += ext_users_total # Create peers report @@ -705,22 +733,26 @@ def get_usage(configuration, peers_report += "\n - %s" % peer if peers_report: peers_report = "Accepted by the following peer:%s" % peers_report - account_usage[username]['peers_report'] = peers_report + account_usage[username]["peers_report"] = peers_report # Create total usage report for each user for usage in account_usage.values(): - usage['total_report'] = "Total usage: %s" \ - % human_readable_filesize(usage['total_bytes']) + usage["total_report"] = "Total usage: %s" % human_readable_filesize( + usage["total_bytes"] + ) # External users are accounted for by their peer # unless the external user also act as a peer result = {} - result['timestamp'] = data.get('timestamp', 0) - result['quota'] = data.get('quota', {}) - result['accounting'] = {username: values for username, values - in account_usage.items() - if not userlist or username in userlist - and username not in skip_ext_users} + result["timestamp"] = data.get("timestamp", 0) + result["quota"] = data.get("quota", {}) + result["accounting"] = { + username: values + for username, values in account_usage.items() + if not userlist + or username in userlist + and username not in skip_ext_users + } return result diff --git a/mig/lib/events.py b/mig/lib/events.py index 157933d3d..b1254e322 100644 --- 
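The reworked human_readable_filesize keeps the exact same IEC math: p = floor(log2(size) / 10) selects the unit and size / 1024**p scales the value. A few sample results, worked out by hand from the code above:

    human_readable_filesize(0)       # '0 B'       (guard clause)
    human_readable_filesize(1536)    # '1.500 KiB' (log2(1536)/10 -> p == 1)
    human_readable_filesize(2**40)   # '1.000 TiB' (p == 4)
    human_readable_filesize("x")     # 'NaN' via the TypeError fallback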
diff --git a/mig/lib/events.py b/mig/lib/events.py
index 157933d3d..b1254e322 100644
--- a/mig/lib/events.py
+++ b/mig/lib/events.py
@@ -433,8 +433,7 @@ def run_cron_command(
         _restore_env(saved_environ, os.environ)
         raise exc
     logger.info(
-        "(%s) done running command for %s: %s" % (
-            pid, target_path, command_str)
+        "(%s) done running command for %s: %s" % (pid, target_path, command_str)
     )
     # logger.debug('(%s) raw output is: %s' % (pid, output_objects))
 
@@ -532,8 +531,7 @@ def run_events_command(
         _restore_env(saved_environ, os.environ)
         raise exc
     logger.info(
-        "(%s) done running command for %s: %s" % (
-            pid, target_path, command_str)
+        "(%s) done running command for %s: %s" % (pid, target_path, command_str)
     )
     # logger.debug('(%s) raw output is: %s' % (pid, output_objects))
 
@@ -624,5 +622,6 @@ def legacy_main(conf, print=print, _exit=sys.exit):
 
 if __name__ == "__main__":
     from mig.shared.conf import get_configuration_object
+
     conf = get_configuration_object(skip_log=True, disable_auth_log=True)
     legacy_main(conf)
diff --git a/mig/lib/janitor.py b/mig/lib/janitor.py
index ea3e0051e..c18590226 100644
--- a/mig/lib/janitor.py
+++ b/mig/lib/janitor.py
@@ -36,8 +36,11 @@
 import os
 import time
 
-from mig.shared.accountreq import accept_account_req, existing_user_collision, \
-    reject_account_req
+from mig.shared.accountreq import (
+    accept_account_req,
+    existing_user_collision,
+    reject_account_req,
+)
 from mig.shared.base import get_user_id
 from mig.shared.fileio import delete_file, listdir
 from mig.shared.pwcrypto import verify_reset_token
@@ -274,7 +277,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
         _logger.info("%r made an invalid account request" % client_id)
         # NOTE: 'invalid' is a list of validation error strings if set
         reason = "invalid request: %s." % ". ".join(req_invalid)
-        (rej_status, rej_err) = reject_account_req(
+        rej_status, rej_err = reject_account_req(
             req_id,
             configuration,
             reason,
@@ -284,21 +287,27 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
         )
         if not rej_status:
             _logger.warning(
-                "failed to reject invalid %r account request: %s" % (client_id,
-                                                                     rej_err)
+                "failed to reject invalid %r account request: %s"
+                % (client_id, rej_err)
             )
         else:
             _logger.info("rejected invalid %r account request" % client_id)
     elif authorized:
-        _logger.info("%r requested renew and authorized password change" %
-                     client_id)
+        _logger.info(
+            "%r requested renew and authorized password change" % client_id
+        )
         peer_id = user_dict.get("peers", [None])[0]
         # NOTE: let authorized reqs (with valid peer) renew even with pw change
         default_renew = True
-        if accept_account_req(req_id, configuration, peer_id,
-                              user_copy=user_copy, admin_copy=admin_copy,
-                              auth_type=auth_type,
-                              default_renew=default_renew):
+        if accept_account_req(
+            req_id,
+            configuration,
+            peer_id,
+            user_copy=user_copy,
+            admin_copy=admin_copy,
+            auth_type=auth_type,
+            default_renew=default_renew,
+        ):
             _logger.info("accepted authorized %r access renew" % client_id)
         else:
             _logger.warning("failed authorized %r access renew" % client_id)
@@ -313,7 +322,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
                 "%r requested and authorized password reset" % client_id
             )
             peer_id = user_dict.get("peers", [None])[0]
-            (acc_status, acc_err) = accept_account_req(
+            acc_status, acc_err = accept_account_req(
                 req_id,
                 configuration,
                 peer_id,
@@ -324,18 +333,18 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
             )
             if not acc_status:
                 _logger.warning(
-                    "failed to accept %r password reset: %s" % (client_id,
-                                                                acc_err)
+                    "failed to accept %r password reset: %s"
+                    % (client_id, acc_err)
                 )
             else:
                 _logger.info("accepted %r password reset" % client_id)
         else:
             _logger.warning(
-                "%r requested password reset with bad token: %s" % (
-                    client_id, reset_token)
+                "%r requested password reset with bad token: %s"
+                % (client_id, reset_token)
             )
             reason = "invalid password reset token"
-            (rej_status, rej_err) = reject_account_req(
+            rej_status, rej_err = reject_account_req(
                 req_id,
                 configuration,
                 reason,
@@ -345,8 +354,8 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
             )
             if not rej_status:
                 _logger.warning(
-                    "failed to reject %r password reset: %s" % (client_id,
-                                                                rej_err)
+                    "failed to reject %r password reset: %s"
+                    % (client_id, rej_err)
                 )
             else:
                 _logger.info("rejected %r password reset" % client_id)
@@ -354,7 +363,7 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
         # NOTE: probably should no longer happen after initial auto clean
         _logger.warning("%r request is now past expire" % client_id)
         reason = "expired request - please re-request if still relevant"
-        (rej_status, rej_err) = reject_account_req(
+        rej_status, rej_err = reject_account_req(
             req_id,
             configuration,
             reason,
@@ -363,15 +372,15 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
             auth_type=auth_type,
         )
         if not rej_status:
-            _logger.warning("failed to reject expired %r request: %s" %
-                            (client_id, rej_err)
-                            )
+            _logger.warning(
+                "failed to reject expired %r request: %s" % (client_id, rej_err)
+            )
         else:
             _logger.info("rejected %r request now past expire" % client_id)
     elif existing_user_collision(configuration, req_dict, client_id):
         _logger.warning("ID collision in request from %r" % client_id)
         reason = "ID collision - please re-request with *existing* ID fields"
-        (rej_status, rej_err) = reject_account_req(
+        rej_status, rej_err = reject_account_req(
             req_id,
             configuration,
             reason,
@@ -381,8 +390,8 @@ def manage_single_req(configuration, req_id, req_path, db_path, now):
         )
         if not rej_status:
             _logger.warning(
-                "failed to reject %r request with ID collision: %s" %
-                (client_id, rej_err)
+                "failed to reject %r request with ID collision: %s"
+                % (client_id, rej_err)
             )
         else:
             _logger.info("rejected %r request with ID collision" % client_id)
@@ -417,8 +426,7 @@ def manage_trivial_user_requests(configuration, now=None):
             continue
         req_id = filename
         req_path = os.path.join(configuration.user_pending, req_id)
-        _logger.debug("checking if account request in %r is trivial" %
-                      req_path)
+        _logger.debug("checking if account request in %r is trivial" % req_path)
         req_age = now - os.path.getmtime(req_path)
         req_age_minutes = req_age / SECS_PER_MINUTE
         if req_age_minutes > MANAGE_TRIVIAL_REQ_MINUTES:
@@ -428,8 +436,7 @@ def manage_trivial_user_requests(configuration, now=None):
             )
             manage_single_req(configuration, req_id, req_path, db_path, now)
             handled += 1
-    _logger.debug("handled %d trivial user account request action(s)" %
-                  handled)
+    _logger.debug("handled %d trivial user account request action(s)" % handled)
     return handled
 
@@ -474,7 +481,7 @@ def remind_and_expire_user_pending(configuration, now=None):
             )
             user_copy = True
             admin_copy = True
-            (rej_status, rej_err) = reject_account_req(
+            rej_status, rej_err = reject_account_req(
                 req_id,
                 configuration,
                 reason,
@@ -483,11 +490,12 @@ def remind_and_expire_user_pending(configuration, now=None):
                 auth_type=auth_type,
             )
             if not rej_status:
-                _logger.warning("failed to expire %s request from %r: %s" %
-                                (req_id, client_id, rej_err))
+                _logger.warning(
+                    "failed to expire %s request from %r: %s"
+                    % (req_id, client_id, rej_err)
+                )
             else:
-                _logger.info("expired %s request from %r" % (req_id,
-                                                             client_id))
+                _logger.info("expired %s request from %r" % (req_id, client_id))
             handled += 1
     _logger.debug("handled %d user account request action(s)" % handled)
     return handled
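The janitor hunks are the same mechanical rewrite seen throughout this patch: hand-wrapped %-format continuations become black's parenthesized form, and redundant parentheses around tuple unpacking are dropped. Schematically, taken from the remind_and_expire_user_pending hunk above:

    # before: manual wrapping, operator at the end of the line
    _logger.warning("failed to expire %s request from %r: %s" %
                    (req_id, client_id, rej_err))

    # after: black moves the operator to the start of the continuation
    _logger.warning(
        "failed to expire %s request from %r: %s"
        % (req_id, client_id, rej_err)
    )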
diff --git a/mig/lib/lustrequota.py b/mig/lib/lustrequota.py
index ea669b936..4006a7263 100644
--- a/mig/lib/lustrequota.py
+++ b/mig/lib/lustrequota.py
@@ -41,13 +41,24 @@
     psutil = None
 
 from mig.shared.base import force_unicode
-from mig.shared.fileio import make_symlink, makedirs_rec, pickle, save_json, \
-    scandir, unpickle, walk, write_file
+from mig.shared.fileio import (
+    make_symlink,
+    makedirs_rec,
+    pickle,
+    save_json,
+    scandir,
+    unpickle,
+    walk,
+    write_file,
+)
 from mig.shared.vgrid import vgrid_flat_name
 
 try:
-    from lustreclient.lfs import lfs_get_project_quota, lfs_set_project_id, \
-        lfs_set_project_quota
+    from lustreclient.lfs import (
+        lfs_get_project_quota,
+        lfs_set_project_id,
+        lfs_set_project_quota,
+    )
 except ImportError:
     lfs_set_project_id = None
     lfs_get_project_quota = None
@@ -63,16 +74,19 @@ def __get_lustre_basepath(configuration, lustre_basepath=None):
     valid_lustre_basepath = None
     for dpart in psutil.disk_partitions(all=True):
         if dpart.fstype == "lustre":
-            if lustre_basepath \
-                    and lustre_basepath.startswith(dpart.mountpoint) \
-                    and os.path.isdir(lustre_basepath):
+            if (
+                lustre_basepath
+                and lustre_basepath.startswith(dpart.mountpoint)
+                and os.path.isdir(lustre_basepath)
+            ):
                 valid_lustre_basepath = lustre_basepath
                 break
             elif dpart.mountpoint.endswith(configuration.server_fqdn):
                 valid_lustre_basepath = dpart.mountpoint
             else:
-                check_lustre_basepath = os.path.join(dpart.mountpoint,
-                                                     configuration.server_fqdn)
+                check_lustre_basepath = os.path.join(
+                    dpart.mountpoint, configuration.server_fqdn
+                )
                 if os.path.isdir(check_lustre_basepath):
                     valid_lustre_basepath = check_lustre_basepath
                     break
@@ -85,8 +99,9 @@ def __get_gocryptfs_socket(configuration, gocryptfs_sock=None):
     otherwise return default if it exists"""
     valid_gocryptfs_sock = None
     if gocryptfs_sock is None:
-        gocryptfs_sock = "/var/run/gocryptfs.%s.sock" \
-            % configuration.server_fqdn
+        gocryptfs_sock = (
+            "/var/run/gocryptfs.%s.sock" % configuration.server_fqdn
+        )
     if os.path.exists(gocryptfs_sock):
         gocryptfs_sock_stat = os.lstat(gocryptfs_sock)
         if stat.S_ISSOCK(gocryptfs_sock_stat.st_mode):
@@ -95,12 +110,14 @@ def __get_gocryptfs_socket(configuration, gocryptfs_sock=None):
     return valid_gocryptfs_sock
 
 
-def __shellexec(configuration,
-                command,
-                args=[],
-                stdin_str=None,
-                stdout_filepath=None,
-                stderr_filepath=None):
+def __shellexec(
+    configuration,
+    command,
+    args=[],
+    stdin_str=None,
+    stdout_filepath=None,
+    stderr_filepath=None,
+):
     """Execute shell command
     Returns (exit_code, stdout, stderr) of subprocess"""
     result = 0
@@ -116,10 +133,8 @@ def __shellexec(configuration,
     __args.extend(args)
     logger.debug("__args: %s" % __args)
     process = subprocess.Popen(
-        __args,
-        stdin=stdin_handle,
-        stdout=stdout_handle,
-        stderr=stderr_handle)
+        __args, stdin=stdin_handle, stdout=stdout_handle, stderr=stderr_handle
+    )
     if stdin_str:
         process.stdin.write(stdin_str.encode())
     stdout, stderr = process.communicate()
@@ -145,28 +160,22 @@ def __shellexec(configuration,
     if stderr:
         stderr = force_unicode(stderr)
     if result == 0:
-        logger.debug("%s %s: rc: %s, stdout: %s, error: %s"
-                     % (command,
-                        " ".join(args),
-                        rc,
-                        stdout,
-                        stderr))
+        logger.debug(
+            "%s %s: rc: %s, stdout: %s, error: %s"
+            % (command, " ".join(args), rc, stdout, stderr)
+        )
     else:
-        logger.error("shellexec: %s %s: rc: %s, stdout: %s, error: %s"
-                     % (command,
-                        " ".join(__args),
-                        rc,
-                        stdout,
-                        stderr))
+        logger.error(
+            "shellexec: %s %s: rc: %s, stdout: %s, error: %s"
+            % (command, " ".join(__args), rc, stdout, stderr)
+        )
 
     return (rc, stdout, stderr)
 
 
-def __set_project_id(configuration,
-                     lustre_basepath,
-                     quota_datapath,
-                     quota_name,
-                     quota_lustre_pid):
+def __set_project_id(
+    configuration, lustre_basepath, quota_datapath, quota_name, quota_lustre_pid
+):
     """Set lustre project *quota_lustre_pid*
     Find the next *free* project id (PID) if *quota_lustre_pid* is occupied
     NOTE: lustre uses a global counter for project id's (PID)
@@ -181,19 +190,22 @@ def __set_project_id(configuration,
     logger = configuration.logger
     next_lustre_pid = quota_lustre_pid
     while next_lustre_pid < max_lustre_pid:
-        (rc, currfiles, _, _, _) \
-            = lfs_get_project_quota(lustre_basepath, next_lustre_pid)
+        rc, currfiles, _, _, _ = lfs_get_project_quota(
+            lustre_basepath, next_lustre_pid
+        )
         if rc != 0:
-            logger.error("Failed to fetch quota for lustre project id: %d, %r"
-                         % (next_lustre_pid, lustre_basepath)
-                         + ", rc: %d" % rc)
+            logger.error(
+                "Failed to fetch quota for lustre project id: %d, %r"
+                % (next_lustre_pid, lustre_basepath)
+                + ", rc: %d" % rc
+            )
             return -1
         if currfiles == 0:
             break
-        logger.info("Skipping project id: %d"
-                    % next_lustre_pid
-                    + " already registered with %d files"
-                    % currfiles)
+        logger.info(
+            "Skipping project id: %d" % next_lustre_pid
+            + " already registered with %d files" % currfiles
+        )
         next_lustre_pid += 1
 
     if next_lustre_pid == max_lustre_pid:
@@ -202,22 +214,28 @@ def __set_project_id(configuration,
 
     # Set new project id
 
-    logger.info("Setting lustre project id: %d for %r: %r"
-                % (next_lustre_pid, quota_name, quota_datapath))
+    logger.info(
+        "Setting lustre project id: %d for %r: %r"
+        % (next_lustre_pid, quota_name, quota_datapath)
+    )
     rc = lfs_set_project_id(quota_datapath, next_lustre_pid, 1)
     if rc != 0:
-        logger.error("lfs_set_project_id failed for lustre project id: %d for %r: %r"
-                     % (next_lustre_pid, quota_name, quota_datapath)
-                     + ", rc: %d" % rc)
+        logger.error(
+            "lfs_set_project_id failed for lustre project id: %d for %r: %r"
+            % (next_lustre_pid, quota_name, quota_datapath)
+            + ", rc: %d" % rc
+        )
         return -1
 
     # Dump lustre pid in quota_datapath and wait for it to appear in the quota
 
-    lustre_pid_filepath = os.path.join(quota_datapath, '.lustrepid')
+    lustre_pid_filepath = os.path.join(quota_datapath, ".lustrepid")
     status = write_file(next_lustre_pid, lustre_pid_filepath, logger)
     if not status:
-        logger.error("Failed write lustre project id: %d for %r to %r"
-                     % (next_lustre_pid, quota_name, quota_datapath))
+        logger.error(
+            "Failed write lustre project id: %d for %r to %r"
+            % (next_lustre_pid, quota_name, quota_datapath)
+        )
         return -1
 
     # Wait for files to appear in quota before returning
@@ -226,36 +244,44 @@ def __set_project_id(configuration,
     waiting = 0
     max_waiting = 60
     while files == 0 and waiting < max_waiting:
-        (rc, files, _, _, _) \
-            = lfs_get_project_quota(lustre_basepath, next_lustre_pid)
+        rc, files, _, _, _ = lfs_get_project_quota(
+            lustre_basepath, next_lustre_pid
+        )
         if rc != 0:
             files = 0
-            logger.error("lfs_get_project_quota failed for:"
-                         + " %d, %r, %r, rc: %d"
-                         % (next_lustre_pid, quota_name, quota_datapath, rc))
+            logger.error(
+                "lfs_get_project_quota failed for:"
+                + " %d, %r, %r, rc: %d"
+                % (next_lustre_pid, quota_name, quota_datapath, rc)
+            )
         if files == 0:
-            logger.info("Waiting for lustre quota: %d: %r: %r"
-                        % (next_lustre_pid, quota_name, quota_datapath))
+            logger.info(
+                "Waiting for lustre quota: %d: %r: %r"
+                % (next_lustre_pid, quota_name, quota_datapath)
+            )
             time.sleep(1)
             waiting += 1
 
     if waiting == max_waiting:
-        logger.error("Failed to fetch quota for:"
-                     + " %d, %r, %r"
-                     % (next_lustre_pid, quota_name, quota_datapath))
+        logger.error(
+            "Failed to fetch quota for:"
+            + " %d, %r, %r" % (next_lustre_pid, quota_name, quota_datapath)
+        )
         return -1
 
     return next_lustre_pid
 
 
-def __update_quota(configuration,
-                   lustre_basepath,
-                   lustre_setting,
-                   quota_name,
-                   quota_type,
-                   data_basefs,
-                   gocryptfs_sock,
-                   timestamp):
+def __update_quota(
+    configuration,
+    lustre_basepath,
+    lustre_setting,
+    quota_name,
+    quota_type,
+    data_basefs,
+    gocryptfs_sock,
+    timestamp,
+):
     """Update quota for *quota_name*,
     if new entry then assign lustre project id and set default quota.
     If existing entry then update quota settings if changed
     """
     logger = configuration.logger
     quota_limits_changed = False
-    next_lustre_pid = lustre_setting.get('next_pid', -1)
+    next_lustre_pid = lustre_setting.get("next_pid", -1)
     if next_lustre_pid == -1:
-        logger.error("Invalid lustre quota next_pid: %d for: %r"
-                     % (next_lustre_pid, quota_name))
+        logger.error(
+            "Invalid lustre quota next_pid: %d for: %r"
+            % (next_lustre_pid, quota_name)
+        )
         return False
 
     # Resolve quota limit and data basepath
 
-    if quota_type == 'vgrid':
+    if quota_type == "vgrid":
         default_quota_limit = configuration.quota_vgrid_limit
         data_basepath = configuration.vgrid_files_writable
     else:
@@ -279,11 +307,13 @@ def __update_quota(configuration,
         data_basepath = configuration.user_home
 
     if data_basepath.startswith(configuration.state_path):
-        rel_data_basepath = data_basepath. \
-            replace(configuration.state_path, "").lstrip(os.sep)
+        rel_data_basepath = data_basepath.replace(
+            configuration.state_path, ""
+        ).lstrip(os.sep)
     else:
-        logger.error("Failed to resolve relative data basepath from: %r"
-                     % data_basepath)
+        logger.error(
+            "Failed to resolve relative data basepath from: %r" % data_basepath
+        )
         return False
 
     # Resolve quota data path
@@ -291,37 +321,38 @@ def __update_quota(configuration,
     if configuration.quota_backend == "lustre":
         quota_basefs = "lustre"
-        quota_datapath = os.path.join(lustre_basepath,
-                                      rel_data_basepath,
-                                      quota_name)
+        quota_datapath = os.path.join(
+            lustre_basepath, rel_data_basepath, quota_name
+        )
     elif configuration.quota_backend == "lustre-gocryptfs":
         quota_basefs = "fuse.gocryptfs"
         stdin_str = os.path.join(rel_data_basepath, quota_name)
         cmd = "gocryptfs-xray -encrypt-paths %s" % gocryptfs_sock
-        (rc, stdout, stderr) = __shellexec(configuration,
-                                           cmd,
-                                           stdin_str=stdin_str)
+        rc, stdout, stderr = __shellexec(
+            configuration, cmd, stdin_str=stdin_str
+        )
         if rc == 0 and stdout:
             encoded_path = stdout.strip()
-            quota_datapath = os.path.join(lustre_basepath,
-                                          encoded_path)
+            quota_datapath = os.path.join(lustre_basepath, encoded_path)
        else:
-            logger.error("Failed to resolve encrypted path for: %r"
-                         % quota_name
-                         + ", rc: %d, error: %s"
-                         % (rc, stderr))
+            logger.error(
+                "Failed to resolve encrypted path for: %r" % quota_name
+                + ", rc: %d, error: %s" % (rc, stderr)
+            )
             return False
     else:
-        logger.error("Invalid quota backend: %r"
-                     % configuration.quota_backend)
+        logger.error("Invalid quota backend: %r" % configuration.quota_backend)
        return False
 
     # Check if valid lustre data dir
 
     if not os.path.isdir(quota_datapath):
-        msg = "skipping entry: %r : %r, no lustre data path: %r" \
-            % (quota_type, quota_name, quota_datapath)
+        msg = "skipping entry: %r : %r, no lustre data path: %r" % (
+            quota_type,
+            quota_name,
+            quota_datapath,
+        )
         # NOTE: log error and return false if dir is missing
         # and we expect data to be on lustre or gocryptfs
         if data_basefs == quota_basefs:
@@ -335,143 +366,170 @@ def __update_quota(configuration,
 
     # Load quota if it exists otherwise new quota
 
-    quota_filepath = os.path.join(configuration.quota_home,
-                                  configuration.quota_backend,
-                                  quota_type,
-                                  "%s.pck" % quota_name)
+    quota_filepath = os.path.join(
+        configuration.quota_home,
+        configuration.quota_backend,
+        quota_type,
+        "%s.pck" % quota_name,
+    )
     if os.path.exists(quota_filepath):
         quota = unpickle(quota_filepath, logger)
         if not quota:
-            logger.error("Failed to load quota settings for: %r from %r"
-                         % (quota_name, quota_filepath))
+            logger.error(
+                "Failed to load quota settings for: %r from %r"
+                % (quota_name, quota_filepath)
+            )
             return False
     else:
-        quota = {'lustre_pid': next_lustre_pid,
-                 'files': -1,
-                 'bytes': -1,
-                 'softlimit_bytes': -1,
-                 'hardlimit_bytes': -1,
-                 }
+        quota = {
+            "lustre_pid": next_lustre_pid,
+            "files": -1,
+            "bytes": -1,
+            "softlimit_bytes": -1,
+            "hardlimit_bytes": -1,
+        }
 
     # Fetch quota lustre pid
 
-    quota_lustre_pid = quota.get('lustre_pid', -1)
+    quota_lustre_pid = quota.get("lustre_pid", -1)
     if quota_lustre_pid == -1:
-        logger.error("Invalid quota lustre pid: %d for %r"
-                     % (quota_lustre_pid, quota_name))
+        logger.error(
+            "Invalid quota lustre pid: %d for %r"
+            % (quota_lustre_pid, quota_name)
+        )
         return False
 
     # If new entry then set lustre project id
 
     new_lustre_pid = -1
     if quota_lustre_pid == next_lustre_pid:
-        new_lustre_pid = __set_project_id(configuration,
-                                          lustre_basepath,
-                                          quota_datapath,
-                                          quota_name,
-                                          quota_lustre_pid)
+        new_lustre_pid = __set_project_id(
+            configuration,
+            lustre_basepath,
+            quota_datapath,
+            quota_name,
+            quota_lustre_pid,
+        )
         if new_lustre_pid == -1:
-            logger.error("Failed to set project id: %d, %r, %r"
-                         % (new_lustre_pid, quota_name, quota_datapath))
+            logger.error(
+                "Failed to set project id: %d, %r, %r"
+                % (new_lustre_pid, quota_name, quota_datapath)
+            )
             return False
-        lustre_setting['next_pid'] = new_lustre_pid + 1
-        quota['lustre_pid'] = quota_lustre_pid = new_lustre_pid
+        lustre_setting["next_pid"] = new_lustre_pid + 1
+        quota["lustre_pid"] = quota_lustre_pid = new_lustre_pid
 
     # Get current quota values for lustre_pid
 
-    (rc, currfiles, currbytes, softlimit_bytes, hardlimit_bytes) \
-        = lfs_get_project_quota(lustre_basepath, quota_lustre_pid)
+    rc, currfiles, currbytes, softlimit_bytes, hardlimit_bytes = (
+        lfs_get_project_quota(lustre_basepath, quota_lustre_pid)
+    )
     if rc != 0:
-        logger.error("lfs_get_project_quota failed for: %d, %r, %r"
-                     % (quota_lustre_pid, quota_name, quota_datapath)
-                     + ", rc: %d" % rc)
+        logger.error(
+            "lfs_get_project_quota failed for: %d, %r, %r"
+            % (quota_lustre_pid, quota_name, quota_datapath)
+            + ", rc: %d" % rc
+        )
         return False
 
     # Update quota info
 
     if currfiles == 0 or currbytes == 0:
-        logger.warning("lustre_basepath: %r: pid: %d: quota_type: %s"
-                       % (lustre_basepath, quota_lustre_pid, quota_type)
-                       + "quota_name: %s, files: %d, bytes: %d"
-                       % (quota_name, currfiles, currbytes))
+        logger.warning(
+            "lustre_basepath: %r: pid: %d: quota_type: %s"
+            % (lustre_basepath, quota_lustre_pid, quota_type)
+            + "quota_name: %s, files: %d, bytes: %d"
+            % (quota_name, currfiles, currbytes)
+        )
 
-    quota['mtime'] = timestamp
-    quota['files'] = currfiles
-    quota['bytes'] = currbytes
+    quota["mtime"] = timestamp
+    quota["files"] = currfiles
+    quota["bytes"] = currbytes
 
     # If new entry use default quota
     # and update quota if changed
 
     if new_lustre_pid > -1:
         quota_limits_changed = True
-        quota['softlimit_bytes'] = default_quota_limit
-        quota['hardlimit_bytes'] = default_quota_limit
-    elif hardlimit_bytes != quota.get('hardlimit_bytes', -1) \
-            or softlimit_bytes != quota.get('softlimit_bytes', -1):
+        quota["softlimit_bytes"] = default_quota_limit
+        quota["hardlimit_bytes"] = default_quota_limit
+    elif hardlimit_bytes != quota.get(
+        "hardlimit_bytes", -1
+    ) or softlimit_bytes != quota.get("softlimit_bytes", -1):
         quota_limits_changed = True
-        quota['softlimit_bytes'] = softlimit_bytes
-        quota['hardlimit_bytes'] = hardlimit_bytes
+        quota["softlimit_bytes"] = softlimit_bytes
+        quota["hardlimit_bytes"] = hardlimit_bytes
 
     if quota_limits_changed:
-        rc = lfs_set_project_quota(quota_datapath,
-                                   quota_lustre_pid,
-                                   quota['softlimit_bytes'],
-                                   quota['hardlimit_bytes'],
-                                   )
+        rc = lfs_set_project_quota(
+            quota_datapath,
+            quota_lustre_pid,
+            quota["softlimit_bytes"],
+            quota["hardlimit_bytes"],
+        )
         if rc != 0:
-            logger.error("Failed to set quota limit: %d/%d"
-                         % (softlimit_bytes,
-                            hardlimit_bytes)
-                         + " for lustre project id: %d, %r, %r, rc: %d"
-                         % (quota_lustre_pid,
-                            quota_name,
-                            quota_datapath,
-                            rc))
+            logger.error(
+                "Failed to set quota limit: %d/%d"
+                % (softlimit_bytes, hardlimit_bytes)
+                + " for lustre project id: %d, %r, %r, rc: %d"
+                % (quota_lustre_pid, quota_name, quota_datapath, rc)
+            )
             return False
 
     # Save current quota
 
-    new_quota_basepath = os.path.join(configuration.quota_home,
-                                      configuration.quota_backend,
-                                      quota_type,
-                                      str(timestamp))
-    if not os.path.exists(new_quota_basepath) \
-            and not makedirs_rec(new_quota_basepath, configuration):
-        logger.error("Failed to create new quota base path: %r"
-                     % new_quota_basepath)
+    new_quota_basepath = os.path.join(
+        configuration.quota_home,
+        configuration.quota_backend,
+        quota_type,
+        str(timestamp),
+    )
+    if not os.path.exists(new_quota_basepath) and not makedirs_rec(
+        new_quota_basepath, configuration
+    ):
+        logger.error(
+            "Failed to create new quota base path: %r" % new_quota_basepath
+        )
         return False
-    new_quota_filepath_pck = os.path.join(new_quota_basepath,
-                                          "%s.pck" % quota_name)
+    new_quota_filepath_pck = os.path.join(
+        new_quota_basepath, "%s.pck" % quota_name
+    )
 
-    logger.debug("Saving: %s: %s: %s -> %r"
-                 % (quota_type, quota_name, quota, new_quota_filepath_pck))
+    logger.debug(
+        "Saving: %s: %s: %s -> %r"
+        % (quota_type, quota_name, quota, new_quota_filepath_pck)
+    )
     status = pickle(quota, new_quota_filepath_pck, logger)
     if not status:
-        logger.error("Failed to save quota for: %r to %r"
-                     % (quota_name, new_quota_filepath_pck))
+        logger.error(
+            "Failed to save quota for: %r to %r"
+            % (quota_name, new_quota_filepath_pck)
+        )
         return False
-    new_quota_filepath_json = os.path.join(new_quota_basepath,
-                                           "%s.json" % quota_name)
-    status = save_json(quota,
-                       new_quota_filepath_json,
-                       logger)
+    new_quota_filepath_json = os.path.join(
+        new_quota_basepath, "%s.json" % quota_name
+    )
+    status = save_json(quota, new_quota_filepath_json, logger)
     if not status:
-        logger.error("Failed to save quota for: %r to %r"
-                     % (quota_name, new_quota_filepath_json))
+        logger.error(
+            "Failed to save quota for: %r to %r"
+            % (quota_name, new_quota_filepath_json)
+        )
         return False
 
     # Create symlink to new quota
 
-    status = make_symlink(new_quota_filepath_pck,
-                          quota_filepath,
-                          logger,
-                          force=True)
+    status = make_symlink(
+        new_quota_filepath_pck, quota_filepath, logger, force=True
+    )
     if not status:
-        logger.error("Failed to make quota symlink for: %r: %r -> %r"
-                     % (quota_name, new_quota_filepath_pck, quota_filepath))
+        logger.error(
+            "Failed to make quota symlink for: %r: %r -> %r"
+            % (quota_name, new_quota_filepath_pck, quota_filepath)
+        )
         return False
 
     return True
@@ -483,9 +541,11 @@ def update_lustre_quota(configuration):
 
     # Check if lustreclient module was imported correctly
 
-    if lfs_set_project_id is None \
-            or lfs_get_project_quota is None \
-            or lfs_set_project_quota is None:
+    if (
+        lfs_set_project_id is None
+        or lfs_get_project_quota is None
+        or lfs_set_project_quota is None
+    ):
         logger.error("Failed to import lustreclient module")
         return False
 
@@ -496,11 +556,11 @@ def update_lustre_quota(configuration):
 
     lustre_basepath = __get_lustre_basepath(configuration)
     if lustre_basepath:
-        logger.debug("Using lustre basepath: %r"
-                     % lustre_basepath)
+        logger.debug("Using lustre basepath: %r" % lustre_basepath)
     else:
-        logger.error("Found no valid lustre mounts for: %s"
-                     % configuration.server_fqdn)
+        logger.error(
+            "Found no valid lustre mounts for: %s" % configuration.server_fqdn
+        )
         return False
 
     # Get gocryptfs socket if enabled
@@ -509,37 +569,38 @@ def update_lustre_quota(configuration):
     if configuration.quota_backend == "lustre-gocryptfs":
         gocryptfs_sock = __get_gocryptfs_socket(configuration)
         if gocryptfs_sock:
-            logger.debug("Using gocryptfs socket: %r"
-                         % gocryptfs_sock)
+            logger.debug("Using gocryptfs socket: %r" % gocryptfs_sock)
        else:
             logger.error("Missing gocryptfs socket")
             return False
 
     # Load lustre quota settings
 
-    lustre_setting_filepath = os.path.join(configuration.quota_home,
-                                           '%s.pck'
-                                           % configuration.quota_backend)
+    lustre_setting_filepath = os.path.join(
+        configuration.quota_home, "%s.pck" % configuration.quota_backend
+    )
     if os.path.exists(lustre_setting_filepath):
-        lustre_setting = unpickle(lustre_setting_filepath,
-                                  logger)
+        lustre_setting = unpickle(lustre_setting_filepath, logger)
        if not lustre_setting:
-            logger.error("Failed to load lustre quota: %r"
-                         % lustre_setting_filepath)
+            logger.error(
+                "Failed to load lustre quota: %r" % lustre_setting_filepath
+            )
            return False
     else:
-        lustre_setting = {'next_pid': 1,
-                          'mtime': 0}
+        lustre_setting = {"next_pid": 1, "mtime": 0}
 
     # Update quota
 
-    quota_targets = {'vgrid': {'basefs': 'lustre',
-                               'entries': {},
-                               },
-                     'user': {'basefs': 'lustre',
-                              'entries': {},
-                              },
-                     }
+    quota_targets = {
+        "vgrid": {
+            "basefs": "lustre",
+            "entries": {},
+        },
+        "user": {
+            "basefs": "lustre",
+            "entries": {},
+        },
+    }
 
     # Resolve basefs if possible
 
@@ -547,26 +608,29 @@ def update_lustre_quota(configuration):
         mountpoint = dpart.mountpoint.rstrip(os.sep)
         fstype = dpart.fstype
         if mountpoint == configuration.vgrid_files_writable.rstrip(os.sep):
-            logger.debug("Found basefs for vgrid data: %r : %r"
-                         % (mountpoint, fstype))
-            quota_targets['vgrid']['basefs'] = fstype
+            logger.debug(
+                "Found basefs for vgrid data: %r : %r" % (mountpoint, fstype)
+            )
+            quota_targets["vgrid"]["basefs"] = fstype
         if mountpoint == configuration.user_home.rstrip(os.sep):
-            logger.debug("Found basefs for user data: %r : %r"
-                         % (mountpoint, fstype))
-            quota_targets['user']['basefs'] = fstype
+            logger.debug(
+                "Found basefs for user data: %r : %r" % (mountpoint, fstype)
+            )
+            quota_targets["user"]["basefs"] = fstype
 
     # Resolve vgrids and sub-vgrids
 
     for root, dirs, _ in walk(configuration.vgrid_home, topdown=True):
         for dirent in dirs:
             vgrid_dirpath = os.path.join(root, dirent)
-            owners_filepath = os.path.join(vgrid_dirpath, 'owners')
+            owners_filepath = os.path.join(vgrid_dirpath, "owners")
             if os.path.isfile(owners_filepath):
                 vgrid = vgrid_flat_name(
-                    vgrid_dirpath[len(configuration.vgrid_home):],
-                    configuration)
+                    vgrid_dirpath[len(configuration.vgrid_home) :],
+                    configuration,
+                )
                 logger.debug("Found vgrid: %r" % vgrid)
-                quota_targets['vgrid']['entries'][vgrid] = 1
+                quota_targets["vgrid"]["entries"][vgrid] = 1
 
     # Resolve users
 
@@ -576,45 +640,46 @@ def update_lustre_quota(configuration):
             userhome = os.readlink(entry.path)
             # NOTE: Relative links are prefixed with 'user_home'
             if not userhome.startswith(os.sep):
-                userhome = os.path.join(configuration.user_home,
-                                        userhome)
+                userhome = os.path.join(configuration.user_home, userhome)
         else:
             userhome = entry.path
         if os.path.isdir(userhome):
             user = os.path.basename(userhome)
         else:
-            logger.debug("skipping non-userhome: %r (%r)"
-                         % (userhome, entry.path))
+            logger.debug(
+                "skipping non-userhome: %r (%r)" % (userhome, entry.path)
+            )
            continue
         # NOTE: Multiple links might point to same user
-        quota_targets['user']['entries'][user] = True
+        quota_targets["user"]["entries"][user] = True
 
     # Update quotas
 
     for quota_type in quota_targets:
         target = quota_targets.get(quota_type, {})
-        data_basefs = target.get('basefs', 'lustre')
-        quota_entries = target.get('entries', {})
+        data_basefs = target.get("basefs", "lustre")
+        quota_entries = target.get("entries", {})
         for quota_entry in quota_entries:
-            status = __update_quota(configuration,
-                                    lustre_basepath,
-                                    lustre_setting,
-                                    quota_entry,
-                                    quota_type,
-                                    data_basefs,
-                                    gocryptfs_sock,
-                                    timestamp)
+            status = __update_quota(
+                configuration,
+                lustre_basepath,
+                lustre_setting,
+                quota_entry,
+                quota_type,
data_basefs,
+                gocryptfs_sock,
+                timestamp,
+            )
             if not status:
                 retval = False
 
     # Save updated lustre quota settings
 
-    lustre_setting['mtime'] = timestamp
-    status = pickle(lustre_setting,
-                    lustre_setting_filepath,
-                    logger)
+    lustre_setting["mtime"] = timestamp
+    status = pickle(lustre_setting, lustre_setting_filepath, logger)
     if not status:
-        logger.error("Failed to save lustra quota settings: %r"
-                     % lustre_setting_filepath)
+        logger.error(
+            "Failed to save lustre quota settings: %r" % lustre_setting_filepath
+        )
 
     return retval
diff --git a/mig/lib/quota.py b/mig/lib/quota.py
index e93830d28..85b795f2c 100644
--- a/mig/lib/quota.py
+++ b/mig/lib/quota.py
@@ -30,20 +30,22 @@
 
 from mig.lib.lustrequota import update_lustre_quota
 
-
-supported_quota_backends = ['lustre', 'lustre-gocryptfs']
+supported_quota_backends = ["lustre", "lustre-gocryptfs"]
 
 
 def update_quota(configuration):
     """Update quota for users and vgrids"""
     retval = False
     logger = configuration.logger
-    if configuration.quota_backend == 'lustre' \
-            or configuration.quota_backend == 'lustre-gocryptfs':
+    if (
+        configuration.quota_backend == "lustre"
+        or configuration.quota_backend == "lustre-gocryptfs"
+    ):
         retval = update_lustre_quota(configuration)
     else:
-        logger.error("quota_backend: %r not in supported_quota_backends: %r"
-                     % (configuration.quota_backend,
-                        supported_quota_backends))
+        logger.error(
+            "quota_backend: %r not in supported_quota_backends: %r"
+            % (configuration.quota_backend, supported_quota_backends)
+        )
 
     return retval
diff --git a/mig/server/checkcloud.py b/mig/server/checkcloud.py
index 15f00aa43..fa73f2dfb 100755
--- a/mig/server/checkcloud.py
+++ b/mig/server/checkcloud.py
@@ -27,20 +27,22 @@
 
 """Check cloud instances allowed and running for users"""
 
-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function
 
 import getopt
 import pickle
 import sys
 
+from mig.shared.cloud import (
+    cloud_load_instance,
+    lookup_user_service_value,
+    status_all_cloud_instances,
+)
 from mig.shared.defaults import keyword_all
-from mig.shared.useradm import init_user_adm, search_users, default_search
-from mig.shared.cloud import lookup_user_service_value, cloud_load_instance, \
-    status_all_cloud_instances
+from mig.shared.useradm import default_search, init_user_adm, search_users
 
 
-def usage(name='checkcloud.py'):
+def usage(name="checkcloud.py"):
     """Usage help"""
     print("""Check cloud access and instance status for users.
@@ -52,48 +54,49 @@ def usage(name='checkcloud.py'): -h Show this help -I CERT_DN Check only for user with ID (distinguished name) -v Verbose output -""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None verbose = False user_file = None search_filter = default_search() - opt_args = 'c:d:hI:v' + opt_args = "c:d:hI:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-I': - search_filter['distinguished_name'] = val - elif opt == '-v': + elif opt == "-I": + search_filter["distinguished_name"] = val + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" % opt) usage() sys.exit(0) if args: - print('Error: Non-option arguments are not supported - missing quotes?') + print("Error: Non-option arguments are not supported - missing quotes?") usage() sys.exit(1) - uid = 'unknown' - (configuration, hits) = search_users(search_filter, conf_path, db_path, - verbose) + uid = "unknown" + configuration, hits = search_users( + search_filter, conf_path, db_path, verbose + ) services = configuration.cloud_services if not hits: print("No matching users in user DB") @@ -102,30 +105,46 @@ def usage(name='checkcloud.py'): # Reuse conf and hits as a sparse user DB for speed conf_path, db_path = configuration, dict(hits) print("Cloud status:") - for (uid, user_dict) in hits: + for uid, user_dict in hits: if verbose: print("Checking %s" % uid) for service in services: - cloud_id = service['service_name'] - cloud_title = service['service_title'] - cloud_flavor = service.get( - "service_provider_flavor", "openstack") + cloud_id = service["service_name"] + cloud_title = service["service_title"] + cloud_flavor = service.get("service_provider_flavor", "openstack") max_instances = lookup_user_service_value( - configuration, uid, service, 'service_max_user_instances') + configuration, uid, service, "service_max_user_instances" + ) max_user_instances = int(max_instances) - print('%s cloud instances allowed for %s: %d' % - (cloud_title, uid, max_user_instances)) + print( + "%s cloud instances allowed for %s: %d" + % (cloud_title, uid, max_user_instances) + ) # Load all user instances and show status - saved_instances = cloud_load_instance(configuration, uid, - cloud_id, keyword_all) - instance_fields = ['public_fqdn', 'status'] + saved_instances = cloud_load_instance( + configuration, uid, cloud_id, keyword_all + ) + instance_fields = ["public_fqdn", "status"] status_map = status_all_cloud_instances( - configuration, uid, cloud_id, cloud_flavor, - list(saved_instances), instance_fields) - for (instance_id, instance_dict) in saved_instances.items(): - instance_label = instance_dict.get('INSTANCE_LABEL', - instance_id) - print('%s cloud instance %s (%s) for %s at %s status: %s' % - (cloud_title, instance_label, instance_id, uid, - status_map[instance_id]['public_fqdn'], - status_map[instance_id]['status'])) + configuration, + uid, + cloud_id, + cloud_flavor, + list(saved_instances), + instance_fields, + ) + for instance_id, 
instance_dict in saved_instances.items(): + instance_label = instance_dict.get( + "INSTANCE_LABEL", instance_id + ) + print( + "%s cloud instance %s (%s) for %s at %s status: %s" + % ( + cloud_title, + instance_label, + instance_id, + uid, + status_map[instance_id]["public_fqdn"], + status_map[instance_id]["status"], + ) + ) diff --git a/mig/server/checktwofactor.py b/mig/server/checktwofactor.py index a76887f40..83fbc6b79 100755 --- a/mig/server/checktwofactor.py +++ b/mig/server/checktwofactor.py @@ -27,19 +27,22 @@ """Check twofactor activation status for users""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import pickle import sys from mig.shared.defaults import keyword_auto -from mig.shared.useradm import init_user_adm, search_users, default_search, \ - user_twofactor_status +from mig.shared.useradm import ( + default_search, + init_user_adm, + search_users, + user_twofactor_status, +) -def usage(name='checktwofactor.py'): +def usage(name="checktwofactor.py"): """Usage help""" print("""Check twofactor auth status for users. @@ -52,71 +55,74 @@ def usage(name='checktwofactor.py'): -h Show this help -I CERT_DN Check only for user with ID (distinguished name) -v Verbose output -""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None fields = keyword_auto include_project_users = False verbose = False user_file = None search_filter = default_search() - opt_args = 'c:d:ghf:I:v' + opt_args = "c:d:ghf:I:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-f': + elif opt == "-f": fields = val.split() - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-I': - search_filter['distinguished_name'] = val - elif opt == '-g': + elif opt == "-I": + search_filter["distinguished_name"] = val + elif opt == "-g": include_project_users = True - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) usage() sys.exit(0) if args: - print('Error: Non-option arguments are not supported - missing quotes?') + print("Error: Non-option arguments are not supported - missing quotes?") usage() sys.exit(1) - uid = 'unknown' + uid = "unknown" errors = [] - (configuration, hits) = search_users(search_filter, conf_path, db_path, - verbose) + configuration, hits = search_users( + search_filter, conf_path, db_path, verbose + ) if not hits: print("No matching users in user DB") else: # Reuse conf and hits as a sparse user DB for speed conf_path, db_path = configuration, dict(hits) print("2FA status:") - for (uid, user_dict) in hits: - if not include_project_users and \ - uid.split('/')[-1].startswith('GDP='): + for uid, user_dict in hits: + if not include_project_users and uid.split("/")[-1].startswith( + "GDP=" + ): continue if verbose: print("Checking %s" % uid) - (_, err) = user_twofactor_status(uid, conf_path, db_path, fields, - verbose) + _, err = user_twofactor_status( + uid, conf_path, db_path, fields, verbose + ) errors += err if errors: - print('\n'.join(errors)) + print("\n".join(errors)) elif verbose: print("%s: OK" % uid) diff --git a/mig/server/chkenabled.py b/mig/server/chkenabled.py index 6c0499f70..ffd950d0e 100755 --- a/mig/server/chkenabled.py +++ b/mig/server/chkenabled.py @@ -29,8 +29,7 @@ detecting which daemons to handle and ignore in init scripts. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import os @@ -46,7 +45,7 @@ from mig.shared.conf import get_configuration_object -def usage(name='chkenabled.py'): +def usage(name="chkenabled.py"): """Usage help""" print("""Lookup site_enable_FEATURE value in MiGserver.conf. @@ -57,45 +56,45 @@ def usage(name='chkenabled.py'): -f Force operations to continue past errors -h Show this help -v Verbose output -""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: +if "__main__" == __name__: args = sys.argv[1:] conf_path = None force = False verbose = False - feature = 'UNSET' - opt_args = 'c:fhv' + feature = "UNSET" + opt_args = "c:fhv" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) if conf_path and not os.path.isfile(conf_path): - print('Failed to read configuration file: %s' % conf_path) + print("Failed to read configuration file: %s" % conf_path) sys.exit(1) if verbose: if conf_path: - print('using configuration in %s' % conf_path) + print("using configuration in %s" % conf_path) else: - print('using configuration from MIG_CONF (or default)') + print("using configuration from MIG_CONF (or default)") if len(args) == 1: feature = args[0] @@ -104,13 +103,13 @@ def usage(name='chkenabled.py'): sys.exit(1) if verbose: - print('Lookup configuration value for %s' % feature) + print("Lookup configuration value for %s" % feature) retval = 42 try: configuration = get_configuration_object(skip_log=True) enabled = getattr(configuration, "site_enable_%s" % feature) if verbose: - print('Configuration value for %s: %s' % (feature, enabled)) + print("Configuration value for %s: %s" % (feature, enabled)) if enabled: retval = 0 except Exception as err: diff --git a/mig/server/chksidroot.py b/mig/server/chksidroot.py index 371ccbda5..3b9e38772 100755 --- a/mig/server/chksidroot.py +++ b/mig/server/chksidroot.py @@ -34,8 +34,7 @@ and rewrite to fail or success depending on output. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import os import re @@ -64,44 +63,46 @@ INVALID_MARKER = "_OUT_OF_BOUNDS_" -if __name__ == '__main__': +if __name__ == "__main__": configuration = get_configuration_object() verbose = False log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] verbose = True if verbose: - print(os.environ.get('MIG_CONF', 'DEFAULT'), configuration.server_fqdn) + print(os.environ.get("MIG_CONF", "DEFAULT"), configuration.server_fqdn) # Use separate logger - logger = daemon_logger("chksidroot", configuration.user_chksidroot_log, - log_level) + logger = daemon_logger( + "chksidroot", configuration.user_chksidroot_log, log_level + ) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates register_hangup_handler(configuration) if verbose: - print('''This is simple SID chroot check helper daemon which just + print("""This is simple SID chroot check helper daemon which just prints the real path for all allowed path requests and the invalid marker for illegal ones. Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -''') - print('Starting chksidroot helper daemon - Ctrl-C to quit') +""") + print("Starting chksidroot helper daemon - Ctrl-C to quit") # NOTE: we use sys stdin directly chksidroot_stdin = sys.stdin addr_path_pattern = re.compile( - "^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})::(/.*)$") + "^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})::(/.*)$" + ) keep_running = True if verbose: - print('Reading commands from sys stdin') + print("Reading commands from sys stdin") while keep_running: try: client_ip = "UNKNOWN" @@ -115,16 +116,17 @@ raw_path = path = match.group(2) logger.info("chksidroot from %s got path: %r" % (client_ip, path)) if not os.path.isabs(path): - logger.error("not an absolute path from %s: %r" % - (client_ip, path)) + logger.error( + "not an absolute path from %s: %r" % (client_ip, path) + ) print(INVALID_MARKER) continue # NOTE: extract sid dir before ANY expansion to avoid escape # with e.g. 
/PATH/TO/OWNID/../OTHERID/somefile.txt # Where sid may be share link or session link id. doc_root = configuration.webserver_home - sharelink_prefix = os.path.join(doc_root, 'share_redirect') - session_prefix = os.path.join(doc_root, 'sid_redirect') + sharelink_prefix = os.path.join(doc_root, "share_redirect") + session_prefix = os.path.join(doc_root, "sid_redirect") is_sharelink = False is_file = False # Make sure absolute but unexpanded path is inside sid dir @@ -137,8 +139,9 @@ root = session_prefix.rstrip(os.sep) + os.sep else: # Only warn to avoid excessive noise from scanners - logger.warning("got path from %s with invalid root: %r" % - (client_ip, path)) + logger.warning( + "got path from %s with invalid root: %r" % (client_ip, path) + ) print(INVALID_MARKER) continue # Extract sid name as first component after root base @@ -152,22 +155,24 @@ # outside base, which is checked later. path = os.path.abspath(path) if not path.startswith(full_prefix): - logger.error("got path from %s outside sid base: %r" % - (client_ip, path)) + logger.error( + "got path from %s outside sid base: %r" % (client_ip, path) + ) print(INVALID_MARKER) continue if is_sharelink: # Share links use Alias to map directly into sharelink_home # and with first char mapping into access mode sub-dir there. - (access_dir, _) = extract_mode_id(configuration, sid_name) - real_root = os.path.join(configuration.sharelink_home, - access_dir) + os.sep + access_dir, _ = extract_mode_id(configuration, sid_name) + real_root = ( + os.path.join(configuration.sharelink_home, access_dir) + + os.sep + ) else: # Session links are directly in webserver_home and they map # either into mig_system_files for empty jobs or into specific # user_home for real job input/output. - real_root = configuration.webserver_home.rstrip(os.sep) + \ - os.sep + real_root = configuration.webserver_home.rstrip(os.sep) + os.sep # NOTE: we cannot completely trust linked path to be safe, # so we first check full prefix on normalized path above to avoid @@ -185,22 +190,27 @@ real_target = None if not link_target or not os.path.exists(link_path): # Only warn to avoid excessive noise from scanners - logger.warning("not a valid link from %s for path %r: %r" % - (client_ip, path, link_path)) + logger.warning( + "not a valid link from %s for path %r: %r" + % (client_ip, path, link_path) + ) print(INVALID_MARKER) continue # Find default wide base root depending on target if link_target.startswith(configuration.user_home): - user_dir = link_target.replace(configuration.user_home, '') + user_dir = link_target.replace(configuration.user_home, "") user_dir = user_dir.lstrip(os.sep).split(os.sep)[0] base_path = os.path.join(configuration.user_home, user_dir) - elif not is_sharelink and \ - link_target.startswith(configuration.mig_system_files): + elif not is_sharelink and link_target.startswith( + configuration.mig_system_files + ): base_path = configuration.mig_system_files.rstrip(os.sep) else: - logger.error("unexpected link target from %s for path %r: %r" - % (client_ip, path, link_target)) + logger.error( + "unexpected link target from %s for path %r: %r" + % (client_ip, path, link_target) + ) print(INVALID_MARKER) continue @@ -209,12 +219,15 @@ is_file = not os.path.isdir(real_target) base_path = real_target else: - logger.warning("could not narrow down base root link from %s: %r" % - (client_ip, link_target)) + logger.warning( + "could not narrow down base root link from %s: %r" + % (client_ip, link_target) + ) # We manually expand sid base. 
- logger.debug("found target %r for link %r" % (link_target, - link_path)) + logger.debug( + "found target %r for link %r" % (link_target, link_path) + ) # Single file sharelinks use direct link to file. If so we # manually expand to direct target. Otherwise we only replace # that prefix of path to translate it to a sharelink dir path. @@ -226,18 +239,29 @@ path = path.replace(full_prefix, link_target, 1) real_path = os.path.realpath(path) - logger.info("check path from %s in base %s or chroot: %r" % - (client_ip, base_path, path)) + logger.info( + "check path from %s in base %s or chroot: %r" + % (client_ip, base_path, path) + ) # Exact match to sid dir does not make sense as we expect a file # IMPORTANT: use path and not real_path here in order to test both - if not valid_user_path(configuration, path, base_path, - allow_equal=is_file, apache_scripts=True): - logger.error("request from %s is outside sid chroot %s: %r (%r)" % - (client_ip, base_path, raw_path, real_path)) + if not valid_user_path( + configuration, + path, + base_path, + allow_equal=is_file, + apache_scripts=True, + ): + logger.error( + "request from %s is outside sid chroot %s: %r (%r)" + % (client_ip, base_path, raw_path, real_path) + ) print(INVALID_MARKER) continue - logger.info("found valid sid chroot path from %s: %r" % - (client_ip, real_path)) + logger.info( + "found valid sid chroot path from %s: %r" + % (client_ip, real_path) + ) print(real_path) # Throttle down a bit to yield @@ -249,8 +273,8 @@ logger.error("unexpected exception: %s" % exc) print(INVALID_MARKER) if verbose: - print('Caught unexpected exception: %s' % exc) + print("Caught unexpected exception: %s" % exc) if verbose: - print('chksidroot helper daemon shutting down') + print("chksidroot helper daemon shutting down") sys.exit(0) diff --git a/mig/server/chkuserroot.py b/mig/server/chkuserroot.py index ec5d9fdcc..36f44d99a 100755 --- a/mig/server/chkuserroot.py +++ b/mig/server/chkuserroot.py @@ -34,8 +34,7 @@ and rewrite to fail or success depending on output. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import os import re @@ -65,44 +64,46 @@ INVALID_MARKER = "_OUT_OF_BOUNDS_" -if __name__ == '__main__': +if __name__ == "__main__": configuration = get_configuration_object() verbose = False log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] verbose = True if verbose: - print(os.environ.get('MIG_CONF', 'DEFAULT'), configuration.server_fqdn) + print(os.environ.get("MIG_CONF", "DEFAULT"), configuration.server_fqdn) # Use separate logger - logger = daemon_logger("chkuserroot", configuration.user_chkuserroot_log, - log_level) + logger = daemon_logger( + "chkuserroot", configuration.user_chkuserroot_log, log_level + ) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates register_hangup_handler(configuration) if verbose: - print('''This is simple user chroot check helper daemon which just + print("""This is simple user chroot check helper daemon which just prints the real path for all allowed path requests and the invalid marker for illegal ones. 
Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -''') - print('Starting chkuserroot helper daemon - Ctrl-C to quit') +""") + print("Starting chkuserroot helper daemon - Ctrl-C to quit") # NOTE: we use sys stdin directly chkuserroot_stdin = sys.stdin addr_path_pattern = re.compile( - "^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})::(/.*)$") + "^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})::(/.*)$" + ) keep_running = True if verbose: - print('Reading commands from sys stdin') + print("Reading commands from sys stdin") while keep_running: try: client_ip = "UNKNOWN" @@ -116,8 +117,9 @@ raw_path = path = match.group(2) logger.info("chkuserroot from %s got path: %r" % (client_ip, path)) if not os.path.isabs(path): - logger.error("not an absolute path from %s: %r" % - (client_ip, path)) + logger.error( + "not an absolute path from %s: %r" % (client_ip, path) + ) print(INVALID_MARKER) continue # NOTE: extract home dir before ANY expansion to avoid escape @@ -125,8 +127,9 @@ root = configuration.user_home.rstrip(os.sep) + os.sep if not path.startswith(root): # Only warn to avoid excessive noise from scanners - logger.warning("got path from %s with invalid root: %r" % - (client_ip, path)) + logger.warning( + "got path from %s with invalid root: %r" % (client_ip, path) + ) print(INVALID_MARKER) continue # Extract name of home as first component after root base @@ -141,31 +144,45 @@ # outside home, which is checked later. path = os.path.abspath(path) if not path.startswith(home_path): - logger.error("got path from %s outside user home: %r" % - (client_ip, raw_path)) + logger.error( + "got path from %s outside user home: %r" + % (client_ip, raw_path) + ) print(INVALID_MARKER) continue real_path = os.path.realpath(path) - logger.debug("check path %r in home %s or chroot" % (path, - home_path)) + logger.debug( + "check path %r in home %s or chroot" % (path, home_path) + ) # Exact match to user home does not make sense as we expect a file # IMPORTANT: use path and not real_path here in order to test both - if not valid_user_path(configuration, path, home_path, - allow_equal=False, apache_scripts=True): - logger.error("path from %s outside user chroot %s: %r (%r)" % - (client_ip, home_path, raw_path, real_path)) + if not valid_user_path( + configuration, + path, + home_path, + allow_equal=False, + apache_scripts=True, + ): + logger.error( + "path from %s outside user chroot %s: %r (%r)" + % (client_ip, home_path, raw_path, real_path) + ) print(INVALID_MARKER) continue - elif not check_account_accessible(configuration, user_id, 'https'): + elif not check_account_accessible(configuration, user_id, "https"): # Only warn to avoid excessive noise from scanners - logger.warning("path from %s in inaccessible %s account: %r (%r)" - % (client_ip, user_id, raw_path, real_path)) + logger.warning( + "path from %s in inaccessible %s account: %r (%r)" + % (client_ip, user_id, raw_path, real_path) + ) print(INVALID_MARKER) continue - logger.info("found valid user chroot path from %s: %r" % - (client_ip, real_path)) + logger.info( + "found valid user chroot path from %s: %r" + % (client_ip, real_path) + ) print(real_path) # Throttle down a bit to yield @@ -177,8 +194,8 @@ logger.error("unexpected exception: %s" % exc) print(INVALID_MARKER) if verbose: - print('Caught unexpected exception: %s' % exc) + print("Caught unexpected exception: %s" % exc) if verbose: - print('chkuserroot helper daemon shutting down') + print("chkuserroot helper daemon shutting down") sys.exit(0) 
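# NOTE (editor aside, not part of the patch): chksidroot.py and chkuserroot.py
# above implement the same small line protocol: each stdin request has the
# form "CLIENT_IP::/abs/path" (see addr_path_pattern) and the daemon answers
# with either the resolved real path or the _OUT_OF_BOUNDS_ marker. A minimal
# sketch of a test client under that assumption follows; the request path and
# MIG_CONF location are hypothetical and must be adjusted to the local setup.

import os
import subprocess


def query_chroot_helper(helper_script, request, conf_path):
    """Send one IP::/path request to a chroot helper and return its reply."""
    env = dict(os.environ, MIG_CONF=conf_path)
    # -u keeps the child's stdout unbuffered so the reply arrives immediately
    proc = subprocess.Popen(
        ["python3", "-u", helper_script],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        universal_newlines=True,
        env=env,
    )
    # One request line in, one reply line out; the daemon keeps looping, so
    # read a single line and terminate rather than waiting for it to exit.
    proc.stdin.write(request + "\n")
    proc.stdin.flush()
    reply = proc.stdout.readline().strip()
    proc.terminate()
    return reply


if __name__ == "__main__":
    # Hypothetical example values - adjust to the local installation
    reply = query_chroot_helper(
        "mig/server/chkuserroot.py",
        "127.0.0.1::/home/mig/state/user_home/testuser/welcome.txt",
        "/home/mig/mig/server/MiGserver.conf",
    )
    print("rejected" if reply == "_OUT_OF_BOUNDS_" else reply)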
diff --git a/mig/server/cleansessions.py b/mig/server/cleansessions.py index 6f9c9704b..8d3d4bd34 100755 --- a/mig/server/cleansessions.py +++ b/mig/server/cleansessions.py @@ -29,18 +29,19 @@ Relies on psutil to lookup established connections for comparison. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import sys from mig.shared.conf import get_configuration_object -from mig.shared.griddaemons.sessions import expire_dead_sessions, \ - expire_dead_sessions_chunked +from mig.shared.griddaemons.sessions import ( + expire_dead_sessions, + expire_dead_sessions_chunked, +) -def usage(name='cleansessions.py'): +def usage(name="cleansessions.py"): """Usage help""" print("""Clean stale sessions from griddaemons. @@ -54,55 +55,57 @@ def usage(name='cleansessions.py'): -u USERNAME Username to specifically target in session clean up where PROTO is one or more specific IO protocols or all if it is left out. Sessions of all users are cleaned unless a specific username is requested. -""" % {'name': name}) +""" % {"name": name}) -if __name__ == '__main__': +if __name__ == "__main__": args = None chunked = False force = False verbose = False username = None - opt_args = 'cfhvu:' + opt_args = "cfhvu:" try: - (opts, args) = getopt.getopt(sys.argv[1:], opt_args) + opts, args = getopt.getopt(sys.argv[1:], opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": chunked = True - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-v': + elif opt == "-v": verbose = True - elif opt == '-u': + elif opt == "-u": username = val else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt)
 
-    proto_list = ['davs', 'sftp', 'ftps']
+    proto_list = ["davs", "sftp", "ftps"]
     if args:
-        proto_list = [proto for proto in args
-                      if proto in proto_list]
+        proto_list = [proto for proto in args if proto in proto_list]
 
     configuration = get_configuration_object()
     if verbose:
-        print('Clean up stale sessions for protocol(s) %r and user %s' %
-              (" ".join(proto_list), username))
+        print(
+            "Clean up stale sessions for protocol(s) %r and user %s"
+            % (" ".join(proto_list), username)
+        )
     retval = 0
     cleaned = []
     configuration = get_configuration_object(skip_log=True)
     for cur_proto in proto_list:
         if chunked:
-            expired = expire_dead_sessions_chunked(configuration, cur_proto,
-                                                   username)
+            expired = expire_dead_sessions_chunked(
+                configuration, cur_proto, username
+            )
         else:
             expired = expire_dead_sessions(configuration, cur_proto, username)
         cleaned += list(expired)
@@ -110,7 +113,9 @@ def usage(name='cleansessions.py'):
     if cleaned:
         if verbose:
             print("\n### Session Clean Summary ###")
-        print('Cleaned %s stale %s sessions:\n%s' %
-              (len(cleaned), " ".join(proto_list), '\n'.join(cleaned)))
+        print(
+            "Cleaned %s stale %s sessions:\n%s"
+            % (len(cleaned), " ".join(proto_list), "\n".join(cleaned))
+        )
         retval = len(cleaned)
     sys.exit(retval)
diff --git a/mig/server/createresource.py b/mig/server/createresource.py
index dcf19ca26..c6c541b79 100755
--- a/mig/server/createresource.py
+++ b/mig/server/createresource.py
@@ -29,6 +29,7 @@
 # Modifications by Martin Rehr
 
 """Add MiG resource from pending request file"""
+
 from __future__ import print_function
 from __future__ import absolute_import
 
@@ -38,31 +39,31 @@
 from mig.shared.resource import create_resource
 
 
-def usage(name='createresource.py'):
+def usage(name="createresource.py"):
     """Usage help"""
     return """Usage: %(name)s RESOURCE_FQDN OWNER_ID RESOURCE_CONFIG
-The script adds .COUNTER to the resources unique id"""\
-        % {'name': name}
+The script adds .COUNTER to the resource's unique id""" % {"name": name}
 
 
 # ## Main ###
 
-if '__main__' == __name__:
+if "__main__" == __name__:
     if not sys.argv[3:]:
         print(usage())
         sys.exit(1)
-    
+
     resource_name = sys.argv[1].strip().lower()
     client_id = sys.argv[2].strip()
     pending_file = sys.argv[3].strip()
-    
+
     configuration = get_configuration_object()
-    (create_status, msg) = create_resource(configuration, client_id,
-                                           resource_name, pending_file)
+    create_status, msg = create_resource(
+        configuration, client_id, resource_name, pending_file
+    )
     if create_status:
-        print('Resource created with ID: %s.%s' % (resource_name, msg))
+        print("Resource created with ID: %s.%s" % (resource_name, msg))
     else:
-        print('Resource creation failed: %s' % msg)
+        print("Resource creation failed: %s" % msg)
diff --git a/mig/server/createuser.py b/mig/server/createuser.py
index ef054f890..b343a2d0a 100755
--- a/mig/server/createuser.py
+++ b/mig/server/createuser.py
@@ -27,27 +27,33 @@
 
 """Add or renew MiG user in user DB and in file system"""
 
-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function
 
-from builtins import input
-from getpass import getpass
 import datetime
 import getopt
 import os
 import sys
 import time
+from builtins import input
+from getpass import getpass
 
 from mig.shared.accountstate import default_account_expire
-from mig.shared.base import fill_distinguished_name, fill_user, canonical_user, \
-    force_native_str_rec
+from mig.shared.base import (
+    canonical_user,
+    fill_distinguished_name,
+    fill_user,
+    force_native_str_rec,
+)
 from
mig.shared.conf import get_configuration_object -from mig.shared.defaults import valid_auth_types, keyword_auto +from mig.shared.defaults import keyword_auto, valid_auth_types from mig.shared.gdp.all import ensure_gdp_user -from mig.shared.pwcrypto import unscramble_password, scramble_password, \ - make_hash +from mig.shared.pwcrypto import ( + make_hash, + scramble_password, + unscramble_password, +) from mig.shared.serial import load -from mig.shared.useradm import init_user_adm, create_user, load_user_dict +from mig.shared.useradm import create_user, init_user_adm, load_user_dict from mig.shared.userdb import default_db_path cert_warn = """ @@ -62,7 +68,7 @@ expire_formats = ["%Y-%m-%d", "%Y-%m-%d %H:%M", "%x", "%c"] -def usage(name='createuser.py'): +def usage(name="createuser.py"): """Usage help""" print("""Create user in the MiG user database and file system. @@ -89,13 +95,13 @@ def usage(name='createuser.py'): -s SLACK_DAYS Allow peers even with account expired within SLACK_DAYS -u USER_FILE Read user information from pickle file -v Verbose output -""" % {'name': name, 'cert_warn': cert_warn}) +""" % {"name": name, "cert_warn": cert_warn}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None - auth_type = 'custom' + auth_type = "custom" expire = None force = False verbose = False @@ -111,22 +117,22 @@ def usage(name='createuser.py'): hash_password = True user_dict = {} override_fields = {} - opt_args = 'a:c:d:e:fhi:o:p:rR:s:u:v' + opt_args = "a:c:d:e:fhi:o:p:rR:s:u:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-a': + for opt, val in opts: + if opt == "-a": auth_type = val - elif opt == '-c': + elif opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-e': + elif opt == "-e": parsed = False for fmt in ["EPOCH"] + expire_formats: try: @@ -141,55 +147,55 @@ def usage(name='createuser.py'): except ValueError: pass if parsed: - override_fields['expire'] = expire - override_fields['status'] = 'temporal' + override_fields["expire"] = expire + override_fields["status"] = "temporal" else: - print('Failed to parse expire value: %s' % val) + print("Failed to parse expire value: %s" % val) sys.exit(1) - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-i': + elif opt == "-i": user_id = val - elif opt == '-o': + elif opt == "-o": short_id = val - override_fields['short_id'] = short_id - elif opt == '-p': + override_fields["short_id"] = short_id + elif opt == "-p": peer_pattern = val - override_fields['peer_pattern'] = peer_pattern - override_fields['status'] = 'temporal' - elif opt == '-r': + override_fields["peer_pattern"] = peer_pattern + override_fields["status"] = "temporal" + elif opt == "-r": default_renew = True ask_renew = False - elif opt == '-R': + elif opt == "-R": role = val - override_fields['role'] = role - elif opt == '-s': + override_fields["role"] = role + elif opt == "-s": # Translate slack days into seconds as - slack_secs = int(float(val)*24*3600) - elif opt == '-u': + slack_secs = int(float(val) * 24 * 3600) + elif opt == "-u": user_file = val # NOTE: hashing should already be handled explicitly hash_password = False - elif opt == '-v': + elif opt == 
"-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" % opt) sys.exit(1) if conf_path and not os.path.isfile(conf_path): - print('Failed to read configuration file: %s' % conf_path) + print("Failed to read configuration file: %s" % conf_path) sys.exit(1) if verbose: if conf_path: if verbose: - print('using configuration in %s' % conf_path) + print("using configuration in %s" % conf_path) else: if verbose: - print('using configuration from MIG_CONF (or default)') + print("using configuration from MIG_CONF (or default)") configuration = get_configuration_object(config_file=conf_path) logger = configuration.logger @@ -198,18 +204,20 @@ def usage(name='createuser.py'): db_path = default_db_path(configuration) if user_file and args: - print('Error: Only one kind of user specification allowed at a time') + print("Error: Only one kind of user specification allowed at a time") usage() sys.exit(1) if auth_type not in valid_auth_types: - print('Error: invalid account auth type %r requested (allowed: %s)' % - (auth_type, ', '.join(valid_auth_types))) + print( + "Error: invalid account auth type %r requested (allowed: %s)" + % (auth_type, ", ".join(valid_auth_types)) + ) usage() sys.exit(1) # NOTE: renew requires original password - if auth_type == 'cert': + if auth_type == "cert": hash_password = False if expire is None: @@ -217,22 +225,23 @@ def usage(name='createuser.py'): raw_user = {} if args: - #logger.debug('createuser called with args: %s' % args) + # logger.debug('createuser called with args: %s' % args) # logger.debug('createuser using default %s and fs %s encoding' % # (sys.getdefaultencoding(), sys.getfilesystemencoding())) try: - raw_user['full_name'] = args[0] - raw_user['organization'] = args[1] - raw_user['state'] = args[2] - raw_user['country'] = args[3] - raw_user['email'] = args[4] - raw_user['comment'] = args[5] - raw_user['password'] = args[6] + raw_user["full_name"] = args[0] + raw_user["organization"] = args[1] + raw_user["state"] = args[2] + raw_user["country"] = args[3] + raw_user["email"] = args[4] + raw_user["comment"] = args[5] + raw_user["password"] = args[6] # Always allow explicit password update on command line - raw_user['authorized'] = True + raw_user["authorized"] = True except IndexError: - print('Error: too few arguments given (expected 7 got %d)' - % len(args)) + print( + "Error: too few arguments given (expected 7 got %d)" % len(args) + ) usage() sys.exit(1) # Force user ID fields to canonical form for consistency @@ -242,59 +251,62 @@ def usage(name='createuser.py'): try: user_dict = load(user_file) except Exception as err: - print('Error in user name extraction: %s' % err) + print("Error in user name extraction: %s" % err) usage() sys.exit(1) elif default_renew and user_id: - #logger.debug('createuser called with user_id: %s' % [user_id]) + # logger.debug('createuser called with user_id: %s' % [user_id]) saved = load_user_dict(logger, user_id, db_path, verbose) if not saved: - print('Error: no such user in user db: %s' % user_id) + print("Error: no such user in user db: %s" % user_id) usage() sys.exit(1) user_dict.update(saved) - del user_dict['expire'] + del user_dict["expire"] elif not configuration.site_enable_gdp: if verbose: - print('''Entering interactive mode -%s''' % cert_warn) - print('Please enter the details for the new user:') - raw_user['full_name'] = input('Full Name: ').title() - raw_user['organization'] = input('Organization: ') - raw_user['state'] = input('State: ') - raw_user['country'] = 
input('2-letter Country Code: ') - raw_user['email'] = input('Email: ') - raw_user['comment'] = input('Comment: ') - raw_user['password'] = getpass('Password: ') + print("""Entering interactive mode +%s""" % cert_warn) + print("Please enter the details for the new user:") + raw_user["full_name"] = input("Full Name: ").title() + raw_user["organization"] = input("Organization: ") + raw_user["state"] = input("State: ") + raw_user["country"] = input("2-letter Country Code: ") + raw_user["email"] = input("Email: ") + raw_user["comment"] = input("Comment: ") + raw_user["password"] = getpass("Password: ") # Force user ID fields to canonical form for consistency # Title name, lowercase email, uppercase country and state, etc. user_dict = canonical_user(configuration, raw_user, raw_user.keys()) else: - print("Error: Missing one or more of the arguments: " - + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] " - + "[EMAIL] [COMMENT] [PASSWORD]") + print( + "Error: Missing one or more of the arguments: " + + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] " + + "[EMAIL] [COMMENT] [PASSWORD]" + ) sys.exit(1) # Encode password if set but not already encoded - if user_dict['password']: + if user_dict["password"]: if hash_password: - user_dict['password_hash'] = make_hash(user_dict['password']) - user_dict['password'] = '' + user_dict["password_hash"] = make_hash(user_dict["password"]) + user_dict["password"] = "" else: salt = configuration.site_password_salt try: - unscramble_password(salt, user_dict['password']) + unscramble_password(salt, user_dict["password"]) except TypeError: - user_dict['password'] = scramble_password( - salt, user_dict['password']) + user_dict["password"] = scramble_password( + salt, user_dict["password"] + ) if user_id: - user_dict['distinguished_name'] = user_id - elif 'distinguished_name' not in user_dict: + user_dict["distinguished_name"] = user_id + elif "distinguished_name" not in user_dict: fill_distinguished_name(user_dict) - #logger.debug('createuser with ID: %s' % [user_dict['distinguished_name']]) + # logger.debug('createuser with ID: %s' % [user_dict['distinguished_name']]) fill_user(user_dict) force_native_str_rec(user_dict) @@ -303,36 +315,48 @@ def usage(name='createuser.py'): # Make sure account expire is set with local certificate or OpenID login - if 'expire' not in user_dict: - override_fields['expire'] = expire + if "expire" not in user_dict: + override_fields["expire"] = expire # NOTE: let non-ID command line values override loaded values - for (key, val) in list(override_fields.items()): + for key, val in list(override_fields.items()): user_dict[key] = val # Now all user fields are set and we can begin adding the user if verbose: - print('using user dict: %s' % user_dict) + print("using user dict: %s" % user_dict) try: - create_user(user_dict, configuration, db_path, force, verbose, - ask_renew, default_renew, verify_peer=peer_pattern, - peer_expire_slack=slack_secs, ask_change_pw=ask_change_pw) + create_user( + user_dict, + configuration, + db_path, + force, + verbose, + ask_renew, + default_renew, + verify_peer=peer_pattern, + peer_expire_slack=slack_secs, + ask_change_pw=ask_change_pw, + ) if configuration.site_enable_gdp: - (success_here, msg) = ensure_gdp_user(configuration, - "127.0.0.1", - user_dict['distinguished_name']) + success_here, msg = ensure_gdp_user( + configuration, "127.0.0.1", user_dict["distinguished_name"] + ) if not success_here: raise Exception("Failed to ensure GDP user: %s" % msg) except Exception as exc: print("Error creating user: %s" 
% exc) import traceback + logger.warning("Error creating user: %s" % traceback.format_exc()) sys.exit(1) - print('Created or updated %s in user database and in file system' % - user_dict['distinguished_name']) + print( + "Created or updated %s in user database and in file system" + % user_dict["distinguished_name"] + ) if user_file: if verbose: - print('Cleaning up tmp file: %s' % user_file) + print("Cleaning up tmp file: %s" % user_file) os.remove(user_file) diff --git a/mig/server/deleteuser.py b/mig/server/deleteuser.py index 94ddea029..826aa7bbb 100755 --- a/mig/server/deleteuser.py +++ b/mig/server/deleteuser.py @@ -27,21 +27,23 @@ """Remove MiG user from user database and file system""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function -from builtins import input import getopt import os import sys +from builtins import input -from mig.shared.base import fill_distinguished_name, fill_user, \ - distinguished_name_to_user +from mig.shared.base import ( + distinguished_name_to_user, + fill_distinguished_name, + fill_user, +) from mig.shared.conf import get_configuration_object -from mig.shared.useradm import init_user_adm, delete_user +from mig.shared.useradm import delete_user, init_user_adm -def usage(name='deleteuser.py'): +def usage(name="deleteuser.py"): """Usage help""" print("""Delete user from MiG user database and file system. @@ -57,66 +59,67 @@ def usage(name='deleteuser.py'): -h Show this help -i CERT_DN Use CERT_DN as user ID no matter what other fields suggest -v Verbose output -""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None force = False verbose = False user_id = None user_dict = {} - opt_args = 'c:d:fhi:v' + opt_args = "c:d:fhi:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-i': + elif opt == "-i": user_id = val - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) if conf_path and not os.path.isfile(conf_path): - print('Failed to read configuration file: %s' % conf_path) + print("Failed to read configuration file: %s" % conf_path) sys.exit(1) if verbose: if conf_path: - print('using configuration in %s' % conf_path) + print("using configuration in %s" % conf_path) else: - print('using configuration from MIG_CONF (or default)') + print("using configuration from MIG_CONF (or default)") configuration = get_configuration_object( - config_file=conf_path, skip_log=True) + config_file=conf_path, skip_log=True + ) if user_id and args: - print('Error: Only one kind of user specification allowed at a time') + print("Error: Only one kind of user specification allowed at a time") usage() sys.exit(1) if args: - user_dict['full_name'] = args[0] + user_dict["full_name"] = args[0] try: - user_dict['organization'] = args[1] - user_dict['state'] = args[2] - user_dict['country'] = args[3] - user_dict['email'] = args[4] + user_dict["organization"] = args[1] + user_dict["state"] = args[2] + user_dict["country"] = args[3] + user_dict["email"] = args[4] except IndexError: # Ignore missing optional arguments @@ -125,19 +128,21 @@ def usage(name='deleteuser.py'): elif user_id: user_dict = distinguished_name_to_user(user_id) elif not configuration.site_enable_gdp: - print('Please enter the details for the user to be removed:') - user_dict['full_name'] = input('Full Name: ').title() - user_dict['organization'] = input('Organization: ') - user_dict['state'] = input('State: ') - user_dict['country'] = input('2-letter Country Code: ') - user_dict['email'] = input('Email: ') + print("Please enter the details for the user to be removed:") + user_dict["full_name"] = input("Full Name: ").title() + user_dict["organization"] = input("Organization: ") + user_dict["state"] = input("State: ") + user_dict["country"] = input("2-letter Country Code: ") + user_dict["email"] = input("Email: ") else: - print("Error: Missing one or more of the arguments: " - + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] " - + "[EMAIL]") + print( + "Error: Missing one or more of the arguments: " + + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] " + + "[EMAIL]" + ) sys.exit(1) - if 'distinguished_name' not in user_dict: + if "distinguished_name" not in user_dict: fill_distinguished_name(user_dict) fill_user(user_dict) @@ -145,11 +150,13 @@ def usage(name='deleteuser.py'): # Now all user fields are set and we can begin deleting the user if verbose: - print('Removing DB entry and dirs for user: %s' % user_dict) + print("Removing DB entry and dirs for user: %s" % user_dict) try: delete_user(user_dict, conf_path, db_path, force, verbose) except Exception as err: print(err) sys.exit(1) - print('Deleted %s from user database and from file system' - % user_dict['distinguished_name']) + print( + "Deleted %s from user database and from file system" + % user_dict["distinguished_name"] + ) diff --git a/mig/server/editgdpuser.py b/mig/server/editgdpuser.py index 2dda1b71f..2369f6f3f 100755 --- a/mig/server/editgdpuser.py +++ b/mig/server/editgdpuser.py @@ -28,20 +28,22 @@ """Edit MiG GDP user in the GDP database and all related GDP project users in the MiG user database and file system""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import os import sys from mig.shared.conf import get_configuration_object -from mig.shared.gdp.all import edit_gdp_user, reset_account_roles, \ - set_account_state +from 
mig.shared.gdp.all import ( + edit_gdp_user, + reset_account_roles, + set_account_state, +) from mig.shared.useradm import init_user_adm -def usage(name='editgdpuser.py'): +def usage(name="editgdpuser.py"): """Usage help""" print("""Edit existing GDP user in the GDP database, @@ -60,15 +62,14 @@ def usage(name='editgdpuser.py'): -r Reset project logins -S ACCOUNT_STATE Change GDP user account state to ACCOUNT_STATE -v Verbose output -""" - % {'name': name}) +""" % {"name": name}) # ## Main ### -if '__main__' == __name__: +if "__main__" == __name__: flock = None - (args, app_dir, mig_db_path) = init_user_adm() + args, app_dir, mig_db_path = init_user_adm() gdp_db_path = None conf_path = None force = False @@ -81,61 +82,61 @@ def usage(name='editgdpuser.py'): # NOTE: Remove fields is NOT supported through 'editgdpuser', # user 'edituser' to remove fields remove_fields = [] - opt_args = 'c:g:d:fhri:S:v' + opt_args = "c:g:d:fhri:S:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": mig_db_path = val - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-g': + elif opt == "-g": gdp_db_path = val - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-i': + elif opt == "-i": user_id = val - elif opt == '-r': + elif opt == "-r": reset_roles = True - elif opt == '-S': + elif opt == "-S": account_state = val - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) if conf_path and not os.path.isfile(conf_path): - print('Failed to read configuration file: %s' % conf_path) + print("Failed to read configuration file: %s" % conf_path) sys.exit(1) if verbose: if conf_path: - print('using configuration in %s' % conf_path) + print("using configuration in %s" % conf_path) else: - print('using configuration from MIG_CONF (or default)') + print("using configuration from MIG_CONF (or default)") configuration = get_configuration_object(config_file=conf_path) if not user_id: - print('Error: Existing user ID is required') + print("Error: Existing user ID is required") usage() sys.exit(1) if args: try: - user_dict['full_name'] = args[0] - user_dict['organization'] = args[1] - user_dict['state'] = args[2] - user_dict['country'] = args[3] - user_dict['email'] = args[4] + user_dict["full_name"] = args[0] + user_dict["organization"] = args[1] + user_dict["state"] = args[2] + user_dict["country"] = args[3] + user_dict["email"] = args[4] except IndexError: # Ignore missing optional arguments @@ -143,9 +144,11 @@ def usage(name='editgdpuser.py'): pass elif not (account_state or reset_roles): - print("Error: Missing one or more of the arguments: " - + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] " - + "[EMAIL]") + print( + "Error: Missing one or more of the arguments: " + + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] " + + "[EMAIL]" + ) sys.exit(1) # Remove empty value fields @@ -155,22 +158,18 @@ def usage(name='editgdpuser.py'): del user_dict[key] if account_state: - (status, msg) = set_account_state( - configuration, - user_id, - account_state, - gdp_db_path=gdp_db_path) + status, msg = set_account_state( + configuration, user_id, account_state, gdp_db_path=gdp_db_path + ) print(msg) elif reset_roles: - (status, msg) = reset_account_roles( - configuration, - user_id, - gdp_db_path=gdp_db_path, - verbose=verbose) + status, msg = reset_account_roles( + configuration, user_id, gdp_db_path=gdp_db_path, verbose=verbose + ) else: if force: print("WARNING: -f disables rollback !!!") - (status, msg) = edit_gdp_user( + status, msg = edit_gdp_user( configuration, user_id, user_dict, @@ -179,7 +178,8 @@ def usage(name='editgdpuser.py'): mig_db_path, gdp_db_path=gdp_db_path, force=force, - verbose=verbose) + verbose=verbose, + ) if not verbose: # NOTE: If verbose everything is printed from functions in GDP if not status: diff --git a/mig/server/editmeta.py b/mig/server/editmeta.py index 45e649b4e..c1dca87e6 100755 --- a/mig/server/editmeta.py +++ b/mig/server/editmeta.py @@ -27,17 +27,16 @@ """Edit MiG user metadata in user database - only non-ID fields""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import os import sys -from mig.shared.useradm import init_user_adm, edit_user +from mig.shared.useradm import edit_user, init_user_adm -def usage(name='editmeta.py'): +def usage(name="editmeta.py"): """Usage help""" print("""Edit existing user (non-ID) metadata in MiG user database. 
@@ -51,13 +50,13 @@ def usage(name='editmeta.py'):
    -r          Remove provided FIELD(S) from USER_ID
    -h          Show this help
    -v          Verbose output
-""" % {'name': name})
+""" % {"name": name})
 
 
 # ## Main ###
 
-if '__main__' == __name__:
-    (args, app_dir, db_path) = init_user_adm()
+if "__main__" == __name__:
+    args, app_dir, db_path = init_user_adm()
     conf_path = None
     force = False
     remove = False
@@ -65,58 +64,65 @@ def usage(name='editmeta.py'):
     verbose = False
     user_id = None
     user_dict = {}
-    opt_args = 'c:d:frhv'
+    opt_args = "c:d:frhv"
     try:
-        (opts, args) = getopt.getopt(args, opt_args)
+        opts, args = getopt.getopt(args, opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)
-    for (opt, val) in opts:
-        if opt == '-c':
+    for opt, val in opts:
+        if opt == "-c":
             conf_path = val
-        elif opt == '-d':
+        elif opt == "-d":
             db_path = val
-        elif opt == '-f':
+        elif opt == "-f":
             force = True
-        elif opt == '-r':
+        elif opt == "-r":
             remove = True
-        elif opt == '-h':
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-v':
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)
 
     if conf_path and not os.path.isfile(conf_path):
-        print('Failed to read configuration file: %s' % conf_path)
+        print("Failed to read configuration file: %s" % conf_path)
         sys.exit(1)
 
     if verbose:
         if conf_path:
-            print('using configuration in %s' % conf_path)
+            print("using configuration in %s" % conf_path)
         else:
-            print('using configuration from MIG_CONF (or default)')
+            print("using configuration from MIG_CONF (or default)")
 
     if remove and len(args) > 1:
-        user_id = user_dict['distinguished_name'] = args[0]
+        user_id = user_dict["distinguished_name"] = args[0]
         remove_fields += args[1:]
     elif len(args) == 3:
-        user_id = user_dict['distinguished_name'] = args[0]
+        user_id = user_dict["distinguished_name"] = args[0]
         user_dict[args[1]] = args[2]
     else:
         usage()
         sys.exit(1)
 
     if verbose:
-        print('Update DB entry for %s: %s' % (user_id, user_dict))
+        print("Update DB entry for %s: %s" % (user_id, user_dict))
     try:
-        user = edit_user(user_id, user_dict, remove_fields, conf_path, db_path,
-                         force, verbose, True)
+        user = edit_user(
+            user_id,
+            user_dict,
+            remove_fields,
+            conf_path,
+            db_path,
+            force,
+            verbose,
+            True,
+        )
     except Exception as err:
         print(err)
         sys.exit(1)
-    print('%s\nchanged to\n%s\nin user database' %
-          (user_id, user))
+    print("%s\nchanged to\n%s\nin user database" % (user_id, user))
diff --git a/mig/server/edituser.py b/mig/server/edituser.py
index 1a8b42868..01e9252f7 100755
--- a/mig/server/edituser.py
+++ b/mig/server/edituser.py
@@ -27,20 +27,19 @@
 """Edit MiG user in user database and file system"""
 
-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function
 
-from builtins import input
 import getopt
 import os
 import sys
+from builtins import input
 
 from mig.shared.base import is_gdp_user
 from mig.shared.conf import get_configuration_object
-from mig.shared.useradm import init_user_adm, edit_user
+from mig.shared.useradm import edit_user, init_user_adm
 
 
-def usage(name='edituser.py'):
+def usage(name="edituser.py"):
     """Usage help"""
 
     print("""Edit existing user in MiG user database and file system.
 Allows
@@ -59,14 +58,13 @@ def usage(name='edituser.py'):
    -r FIELDS   Remove FIELDS for user in user DB
    -R ROLES    Change user affiliation to ROLES
    -v          Verbose output
-"""
-          % {'name': name})
+""" % {"name": name})
 
 
 # ## Main ###
 
-if '__main__' == __name__:
-    (args, app_dir, db_path) = init_user_adm()
+if "__main__" == __name__:
+    args, app_dir, db_path = init_user_adm()
     conf_path = None
     force = False
     verbose = False
@@ -75,52 +73,53 @@ def usage(name='edituser.py'):
     role = None
     remove_fields = []
     user_dict = {}
-    opt_args = 'c:d:fhi:o:r:R:v'
+    opt_args = "c:d:fhi:o:r:R:v"
     try:
-        (opts, args) = getopt.getopt(args, opt_args)
+        opts, args = getopt.getopt(args, opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)
-    for (opt, val) in opts:
-        if opt == '-c':
+    for opt, val in opts:
+        if opt == "-c":
             conf_path = val
-        elif opt == '-d':
+        elif opt == "-d":
             db_path = val
-        elif opt == '-f':
+        elif opt == "-f":
             force = True
-        elif opt == '-h':
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-i':
+        elif opt == "-i":
             user_id = val
-        elif opt == '-o':
+        elif opt == "-o":
             short_id = val
-        elif opt == '-r':
+        elif opt == "-r":
             remove_fields += val.split()
-        elif opt == '-R':
+        elif opt == "-R":
             role = val
-        elif opt == '-v':
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)
 
     if conf_path and not os.path.isfile(conf_path):
-        print('Failed to read configuration file: %s' % conf_path)
+        print("Failed to read configuration file: %s" % conf_path)
         sys.exit(1)
 
     if verbose:
         if conf_path:
-            print('using configuration in %s' % conf_path)
+            print("using configuration in %s" % conf_path)
         else:
-            print('using configuration from MIG_CONF (or default)')
+            print("using configuration from MIG_CONF (or default)")
 
     configuration = get_configuration_object(
-        config_file=conf_path, skip_log=True)
+        config_file=conf_path, skip_log=True
+    )
 
     if not user_id:
-        print('Error: Existing user ID is required')
+        print("Error: Existing user ID is required")
         usage()
         sys.exit(1)
 
@@ -132,13 +131,13 @@ def usage(name='edituser.py'):
 
     if args:
         try:
-            user_dict['full_name'] = args[0]
-            user_dict['organization'] = args[1]
-            user_dict['state'] = args[2]
-            user_dict['country'] = args[3]
-            user_dict['email'] = args[4]
-            user_dict['comment'] = args[5]
-            user_dict['password'] = args[6]
+            user_dict["full_name"] = args[0]
+            user_dict["organization"] = args[1]
+            user_dict["state"] = args[2]
+            user_dict["country"] = args[3]
+            user_dict["email"] = args[4]
+            user_dict["comment"] = args[5]
+            user_dict["password"] = args[6]
         except IndexError:
             # Ignore missing optional arguments
@@ -146,26 +145,28 @@ def usage(name='edituser.py'):
             pass
     elif not configuration.site_enable_gdp:
         # NOTE: We do not allow interactive user management on GDP systems
-        print('Please enter the new details for %s:' % user_id)
-        print('[enter to skip field]')
-        user_dict['full_name'] = input('Full Name: ').title()
-        user_dict['organization'] = input('Organization: ')
-        user_dict['state'] = input('State: ')
-        user_dict['country'] = input('2-letter Country Code: ')
-        user_dict['email'] = input('Email: ')
+        print("Please enter the new details for %s:" % user_id)
+        print("[enter to skip field]")
+        user_dict["full_name"] = input("Full Name: ").title()
+        user_dict["organization"] = input("Organization: ")
+        user_dict["state"] = input("State: ")
+        user_dict["country"] = input("2-letter Country Code: ")
+        user_dict["email"] = input("Email: ")
     else:
-        print("Error: Missing one or more of the arguments: "
-              + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] "
-              + "[EMAIL] [COMMENT] [PASSWORD]")
+        print(
+            "Error: Missing one or more of the arguments: "
+            + "[FULL_NAME] [ORGANIZATION] [STATE] [COUNTRY] "
+            + "[EMAIL] [COMMENT] [PASSWORD]"
+        )
         sys.exit(1)
 
     # Pass optional short_id as well
     if short_id:
-        user_dict['short_id'] = short_id
+        user_dict["short_id"] = short_id
 
     # Pass optional role as well
     if role:
-        user_dict['role'] = role
+        user_dict["role"] = role
 
     # Remove empty value fields
     # NOTE: force list copy here as we delete inline below
@@ -174,14 +175,23 @@ def usage(name='edituser.py'):
             del user_dict[key]
 
     if verbose:
-        print('Update DB entry and dirs for %s: %s' % (user_id, user_dict))
+        print("Update DB entry and dirs for %s: %s" % (user_id, user_dict))
     try:
-        user = edit_user(user_id, user_dict, remove_fields, conf_path, db_path, force,
-                         verbose)
+        user = edit_user(
+            user_id,
+            user_dict,
+            remove_fields,
+            conf_path,
+            db_path,
+            force,
+            verbose,
+        )
     except Exception as err:
         print(err)
         sys.exit(1)
-    print('%s\nchanged to\n%s\nin user database and file system' %
-          (user_id, user['distinguished_name']))
+    print(
+        "%s\nchanged to\n%s\nin user database and file system"
+        % (user_id, user["distinguished_name"])
+    )
     print()
-    print('Please revoke/reissue any related certificates!')
+    print("Please revoke/reissue any related certificates!")
diff --git a/mig/server/genoiddiscovery.py b/mig/server/genoiddiscovery.py
index d034ac012..224012383 100755
--- a/mig/server/genoiddiscovery.py
+++ b/mig/server/genoiddiscovery.py
@@ -30,8 +30,7 @@
 helper function for details.
 """
 
-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function
 
 import getopt
 import os
@@ -41,7 +40,7 @@
 from mig.shared.httpsclient import generate_openid_discovery_doc
 
 
-def usage(name='genoiddiscovery.py'):
+def usage(name="genoiddiscovery.py"):
     """Usage help"""
 
     print("""Generate OpenID 2.0 discovery information for this site.
@@ -52,47 +51,47 @@ def usage(name='genoiddiscovery.py'):
    -f          Force operations to continue past errors
    -h          Show this help
    -v          Verbose output
-""" % {'name': name})
+""" % {"name": name})
 
 
-if '__main__' == __name__:
+if "__main__" == __name__:
     args = sys.argv[1:]
     conf_path = None
     force = False
     verbose = False
-    opt_args = 'c:fhv'
+    opt_args = "c:fhv"
     try:
-        (opts, args) = getopt.getopt(args, opt_args)
+        opts, args = getopt.getopt(args, opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)
-    for (opt, val) in opts:
-        if opt == '-c':
+    for opt, val in opts:
+        if opt == "-c":
             conf_path = val
-        elif opt == '-f':
+        elif opt == "-f":
             force = True
-        elif opt == '-h':
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-v':
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)
 
     if conf_path and not os.path.isfile(conf_path):
-        print('Failed to read configuration file: %s' % conf_path)
+        print("Failed to read configuration file: %s" % conf_path)
         sys.exit(1)
 
     if verbose:
         if conf_path:
-            print('using configuration in %s' % conf_path)
+            print("using configuration in %s" % conf_path)
         else:
-            print('using configuration from MIG_CONF (or default)')
+            print("using configuration from MIG_CONF (or default)")
 
     if args:
-        print('Got unexpected non-option arguments!')
+        print("Got unexpected non-option arguments!")
         usage()
         sys.exit(1)
diff --git a/mig/server/grid_cron.py b/mig/server/grid_cron.py
index e903e0c1b..5acbd9d2c 100755
--- a/mig/server/grid_cron.py
+++ b/mig/server/grid_cron.py
@@ -32,8 +32,7 @@
 Requires watchdog module (https://pypi.python.org/pypi/watchdog).
 """
 
-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function
 
 import datetime
 import logging
@@ -45,20 +44,36 @@
 import time
 
 try:
+    from watchdog.events import (
+        DirCreatedEvent,
+        FileCreatedEvent,
+        FileModifiedEvent,
+        PatternMatchingEventHandler,
+    )
     from watchdog.observers import Observer
-    from watchdog.events import PatternMatchingEventHandler, \
-        FileModifiedEvent, FileCreatedEvent, DirCreatedEvent
 except ImportError:
-    print('ERROR: the python watchdog module is required for this daemon')
+    print("ERROR: the python watchdog module is required for this daemon")
     sys.exit(1)
 
 from mig.lib.daemon import check_stop, register_stop_handler, stop_running
-from mig.lib.events import get_time_expand_map, parse_crontab, cron_match, \
-    parse_atjobs, at_remain, run_cron_command
-from mig.shared.base import force_utf8, client_dir_id, client_id_dir
+from mig.lib.events import (
+    at_remain,
+    cron_match,
+    get_time_expand_map,
+    parse_atjobs,
+    parse_crontab,
+    run_cron_command,
+)
+from mig.shared.base import client_dir_id, client_id_dir, force_utf8
 from mig.shared.conf import get_configuration_object
-from mig.shared.defaults import crontab_name, atjobs_name, cron_output_dir, \
-    cron_log_name, cron_log_size, cron_log_cnt
+from mig.shared.defaults import (
+    atjobs_name,
+    cron_log_cnt,
+    cron_log_name,
+    cron_log_size,
+    cron_output_dir,
+    crontab_name,
+)
 from mig.shared.fileio import scandir, walk
 from mig.shared.logger import daemon_logger, register_hangup_handler
 
@@ -71,16 +86,15 @@
 # Global state helpers used in a number of functions and methods
 
 shared_state = {}
-shared_state['base_dir'] = None
-shared_state['base_dir_len'] = 0
-shared_state['crontab_inotify'] = None
-shared_state['crontab_handler'] = None
+shared_state["base_dir"] = None
+shared_state["base_dir_len"] = 0
+shared_state["crontab_inotify"] = None
+shared_state["crontab_handler"] = None
 
-(configuration, logger) = (None, None)
+configuration, logger = (None, None)
 
 
 class MiGCrontabEventHandler(PatternMatchingEventHandler):
-
     """Crontab pattern-matching event handler to take care of crontab/atjobs changes and update the global crontab database.
""" @@ -95,9 +109,12 @@ def __init__( """Constructor""" PatternMatchingEventHandler.__init__( - self, patterns=patterns, ignore_patterns=ignore_patterns, + self, + patterns=patterns, + ignore_patterns=ignore_patterns, ignore_directories=ignore_directories, - case_sensitive=case_sensitive) + case_sensitive=case_sensitive, + ) def __update_crontab_monitor( self, @@ -108,25 +125,28 @@ def __update_crontab_monitor( pid = multiprocessing.current_process().pid - if state == 'created': + if state == "created": # logger.debug('(%s) Updating crontab monitor for src_path: %s, event: %s' # % (pid, src_path, state)) - print('(%s) Updating crontab monitor for src_path: %s, event: %s' - % (pid, src_path, state)) + print( + "(%s) Updating crontab monitor for src_path: %s, event: %s" + % (pid, src_path, state) + ) if os.path.exists(src_path): # _crontab_monitor_lock.acquire() - if src_path not in shared_state['crontab_inotify']._wd_for_path: + if src_path not in shared_state["crontab_inotify"]._wd_for_path: # logger.debug('(%s) Adding watch for: %s' % (pid, # src_path)) - shared_state['crontab_inotify'].add_watch( - force_utf8(src_path)) + shared_state["crontab_inotify"].add_watch( + force_utf8(src_path) + ) # Fire 'modified' events for all dirs and files in subpath # to ensure that all crontab files are loaded @@ -137,23 +157,25 @@ def __update_crontab_monitor( # logger.debug('(%s) Dispatch DirCreatedEvent for: %s' # % (pid, ent.path)) - shared_state['crontab_handler'].dispatch( - DirCreatedEvent(ent.path)) - elif ent.path.find(configuration.user_settings) \ - > -1: + shared_state["crontab_handler"].dispatch( + DirCreatedEvent(ent.path) + ) + elif ent.path.find(configuration.user_settings) > -1: # logger.debug('(%s) Dispatch FileCreatedEvent for: %s' # % (pid, ent.path)) - shared_state['crontab_handler'].dispatch( - FileCreatedEvent(ent.path)) + shared_state["crontab_handler"].dispatch( + FileCreatedEvent(ent.path) + ) # else: # logger.debug('(%s) crontab_monitor watch already exists for: %s' # % (pid, src_path)) else: - logger.debug('(%s) unhandled event: %s for: %s' % (pid, - state, src_path)) + logger.debug( + "(%s) unhandled event: %s for: %s" % (pid, state, src_path) + ) def update_crontabs(self, event): """Handle all crontab updates""" @@ -165,52 +187,59 @@ def update_crontabs(self, event): if event.is_directory: self.__update_crontab_monitor(configuration, src_path, state) elif os.path.basename(src_path) == crontab_name: - logger.debug('(%s) %s -> Updating crontab for: %s' % (pid, - state, src_path)) - rel_path = src_path[len(configuration.user_settings):] + logger.debug( + "(%s) %s -> Updating crontab for: %s" % (pid, state, src_path) + ) + rel_path = src_path[len(configuration.user_settings) :] client_dir = os.path.basename(os.path.dirname(src_path)) client_id = client_dir_id(client_dir) user_home = os.path.join(configuration.user_home, client_dir) - logger.info('(%s) refresh %s crontab from %s' % (pid, - client_id, src_path)) - if state == 'deleted': + logger.info( + "(%s) refresh %s crontab from %s" % (pid, client_id, src_path) + ) + if state == "deleted": cur_crontab = [] - logger.debug("(%s) deleted crontab from '%s'" % - (pid, src_path)) + logger.debug("(%s) deleted crontab from '%s'" % (pid, src_path)) else: cur_crontab = parse_crontab(configuration, client_id, src_path) - logger.debug("(%s) loaded new crontab from '%s':\n%s" % - (pid, src_path, cur_crontab)) + logger.debug( + "(%s) loaded new crontab from '%s':\n%s" + % (pid, src_path, cur_crontab) + ) # Replace crontabs for this user 
             all_crontabs[src_path] = cur_crontab
-            logger.debug('(%s) all crontabs: %s' % (pid, all_crontabs))
+            logger.debug("(%s) all crontabs: %s" % (pid, all_crontabs))
         elif os.path.basename(src_path) == atjobs_name:
-            logger.debug('(%s) %s -> Updating atjobs for: %s' % (pid,
-                         state, src_path))
-            rel_path = src_path[len(configuration.user_settings):]
+            logger.debug(
+                "(%s) %s -> Updating atjobs for: %s" % (pid, state, src_path)
+            )
+            rel_path = src_path[len(configuration.user_settings) :]
             client_dir = os.path.basename(os.path.dirname(src_path))
             client_id = client_dir_id(client_dir)
             user_home = os.path.join(configuration.user_home, client_dir)
-            logger.info('(%s) refresh %s atjobs from %s' % (pid,
-                        client_id, src_path))
-            if state == 'deleted':
+            logger.info(
+                "(%s) refresh %s atjobs from %s" % (pid, client_id, src_path)
+            )
+            if state == "deleted":
                 cur_atjobs = []
-                logger.debug("(%s) deleted atjobs from '%s'" %
-                             (pid, src_path))
+                logger.debug("(%s) deleted atjobs from '%s'" % (pid, src_path))
             else:
                 cur_atjobs = parse_atjobs(configuration, client_id, src_path)
-                logger.debug("(%s) loaded new atjobs from '%s':\n%s" %
-                             (pid, src_path, cur_atjobs))
+                logger.debug(
+                    "(%s) loaded new atjobs from '%s':\n%s"
+                    % (pid, src_path, cur_atjobs)
+                )
 
             # Replace atjobs for this user
             all_atjobs[src_path] = cur_atjobs
-            logger.debug('(%s) all atjobs: %s' % (pid, all_atjobs))
+            logger.debug("(%s) all atjobs: %s" % (pid, all_atjobs))
         else:
-            logger.debug('(%s) %s skipping non-cron file: %s' % (pid,
-                         state, src_path))
+            logger.debug(
+                "(%s) %s skipping non-cron file: %s" % (pid, state, src_path)
+            )
 
     def on_modified(self, event):
         """Handle modified crontab file"""
@@ -232,24 +261,26 @@ def __cron_log(configuration, client_id, msg, level="info"):
     """Wrapper to send a single msg to user cron log file"""
 
     client_dir = client_id_dir(client_id)
-    log_dir_path = os.path.join(configuration.user_home, client_dir,
-                                cron_output_dir)
+    log_dir_path = os.path.join(
+        configuration.user_home, client_dir, cron_output_dir
+    )
     log_path = os.path.join(log_dir_path, cron_log_name)
     if not os.path.exists(log_dir_path):
         try:
             os.makedirs(log_dir_path)
         except:
             pass
-    cron_logger = logging.getLogger('cron')
+    cron_logger = logging.getLogger("cron")
     cron_logger.setLevel(logging.INFO)
     handler = logging.handlers.RotatingFileHandler(
-        log_path, maxBytes=cron_log_size, backupCount=cron_log_cnt - 1)
-    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
+        log_path, maxBytes=cron_log_size, backupCount=cron_log_cnt - 1
+    )
+    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
     handler.setFormatter(formatter)
     cron_logger.addHandler(handler)
-    if level == 'error':
+    if level == "error":
         cron_logger.error(msg)
-    elif level == 'warning':
+    elif level == "warning":
         cron_logger.warning(msg)
     else:
         cron_logger.info(msg)
@@ -261,60 +292,76 @@ def __cron_log(configuration, client_id, msg, level="info"):
 
 def __cron_err(configuration, client_id, msg):
     """Wrapper to send a single error msg to client_id cron log"""
 
-    __cron_log(configuration, client_id, msg, 'error')
+    __cron_log(configuration, client_id, msg, "error")
 
 
 def __cron_warn(configuration, client_id, msg):
     """Wrapper to send a single warning msg to client_id cron log"""
 
-    __cron_log(configuration, client_id, msg, 'warning')
+    __cron_log(configuration, client_id, msg, "warning")
 
 
 def __cron_info(configuration, client_id, msg):
     """Wrapper to send a single info msg to client_id cron log"""
 
-    __cron_log(configuration, client_id, msg, 'info')
+    __cron_log(configuration, client_id, msg, "info")
 
 
 def __handle_cronjob(configuration, client_id, timestamp, crontab_entry):
     """Actually handle valid crontab entry which is due"""
 
     pid = multiprocessing.current_process().pid
-    logger.info('(%s) in handling of %s for %s' % (pid,
-                crontab_entry['command'],
-                client_id))
-    __cron_info(configuration, client_id, 'handle %s for %s' %
-                (crontab_entry['command'], client_id))
-
-    if crontab_entry['run_as'] != client_id:
-        logger.error('(%s) skipping due to owner mismatch for %s and %s!' %
-                     (pid, client_id, crontab_entry))
+    logger.info(
+        "(%s) in handling of %s for %s"
+        % (pid, crontab_entry["command"], client_id)
+    )
+    __cron_info(
+        configuration,
+        client_id,
+        "handle %s for %s" % (crontab_entry["command"], client_id),
+    )
+
+    if crontab_entry["run_as"] != client_id:
+        logger.error(
+            "(%s) skipping due to owner mismatch for %s and %s!"
+            % (pid, client_id, crontab_entry)
+        )
         return False
 
     # Expand dynamic time variables in argument once and for all
 
     expand_map = get_time_expand_map(timestamp, crontab_entry)
-    command_list = crontab_entry['command'][:1]
-    for argument in crontab_entry['command'][1:]:
+    command_list = crontab_entry["command"][:1]
+    for argument in crontab_entry["command"][1:]:
         filled_argument = argument
-        for (key, val) in expand_map.items():
+        for key, val in expand_map.items():
             filled_argument = filled_argument.replace(key, val)
-        __cron_info(configuration, client_id,
-                    'expanded argument %s to %s' %
-                    (argument, filled_argument))
+        __cron_info(
+            configuration,
+            client_id,
+            "expanded argument %s to %s" % (argument, filled_argument),
+        )
         command_list.append(filled_argument)
     try:
         run_cron_command(command_list, client_id, crontab_entry, configuration)
-        logger.info('(%s) done running command for %s: %s' %
-                    (pid, client_id, ' '.join(command_list)))
-        __cron_info(configuration, client_id,
-                    'ran command: %s' % ' '.join(command_list))
+        logger.info(
+            "(%s) done running command for %s: %s"
+            % (pid, client_id, " ".join(command_list))
+        )
+        __cron_info(
+            configuration, client_id, "ran command: %s" % " ".join(command_list)
+        )
     except Exception as exc:
-        command_str = ' '.join(command_list)
-        logger.error('(%s) failed to run command for %s: %s (%s)' %
-                     (pid, client_id, command_str, exc))
-        __cron_err(configuration, client_id,
-                   'failed to run command: %s (%s)' % (command_str, exc))
+        command_str = " ".join(command_list)
+        logger.error(
+            "(%s) failed to run command for %s: %s (%s)"
+            % (pid, client_id, command_str, exc)
+        )
+        __cron_err(
+            configuration,
+            client_id,
+            "failed to run command: %s (%s)" % (command_str, exc),
+        )
 
 
 def run_handler(configuration, client_id, timestamp, crontab_entry):
@@ -329,10 +376,10 @@ def run_handler(configuration, client_id, timestamp, crontab_entry):
     waiting_for_worker_resources = True
     while waiting_for_worker_resources:
         try:
-            worker = \
-                multiprocessing.Process(target=__handle_cronjob,
-                                        args=(configuration, client_id,
-                                              timestamp, crontab_entry))
+            worker = multiprocessing.Process(
+                target=__handle_cronjob,
+                args=(configuration, client_id, timestamp, crontab_entry),
+            )
             worker.daemon = True
             worker.start()
             waiting_for_worker_resources = False
@@ -349,44 +396,48 @@ def monitor(configuration):
 
     pid = multiprocessing.current_process().pid
 
-    print('Starting global crontab monitor process')
-    logger.info('Starting global crontab monitor process')
+    print("Starting global crontab monitor process")
+    logger.info("Starting global crontab monitor process")
 
     # Set base_dir and base_dir_len
-    shared_state['base_dir'] = os.path.join(configuration.user_settings)
-    shared_state['base_dir_len'] = len(shared_state['base_dir'])
+    shared_state["base_dir"] = os.path.join(configuration.user_settings)
+    shared_state["base_dir_len"] = len(shared_state["base_dir"])
 
     # Allow e.g. logrotate to force log re-open after rotates
     register_hangup_handler(configuration)
 
     # Monitor crontab configurations
 
-    crontab_monitor_home = shared_state['base_dir']
+    crontab_monitor_home = shared_state["base_dir"]
     recursive_crontab_monitor = True
 
     crontab_monitor = Observer()
-    crontab_pattern = os.path.join(crontab_monitor_home, '*', crontab_name)
-    atjobs_pattern = os.path.join(crontab_monitor_home, '*', atjobs_name)
-    shared_state['crontab_handler'] = MiGCrontabEventHandler(
-        patterns=[crontab_pattern, atjobs_pattern], ignore_directories=False,
-        case_sensitive=True)
-
-    crontab_monitor.schedule(shared_state['crontab_handler'],
-                             configuration.user_settings,
-                             recursive=recursive_crontab_monitor)
+    crontab_pattern = os.path.join(crontab_monitor_home, "*", crontab_name)
+    atjobs_pattern = os.path.join(crontab_monitor_home, "*", atjobs_name)
+    shared_state["crontab_handler"] = MiGCrontabEventHandler(
+        patterns=[crontab_pattern, atjobs_pattern],
+        ignore_directories=False,
+        case_sensitive=True,
+    )
+
+    crontab_monitor.schedule(
+        shared_state["crontab_handler"],
+        configuration.user_settings,
+        recursive=recursive_crontab_monitor,
+    )
     crontab_monitor.start()
 
     if len(crontab_monitor._emitters) != 1:
-        logger.error('(%s) Number of crontab_monitor._emitters != 1' % pid)
+        logger.error("(%s) Number of crontab_monitor._emitters != 1" % pid)
         return 1
     crontab_monitor_emitter = min(crontab_monitor._emitters)
-    if not hasattr(crontab_monitor_emitter, '_inotify'):
-        logger.error('(%s) crontab_monitor_emitter require inotify' % pid)
+    if not hasattr(crontab_monitor_emitter, "_inotify"):
+        logger.error("(%s) crontab_monitor_emitter require inotify" % pid)
         return 1
 
-    shared_state['crontab_inotify'] = crontab_monitor_emitter._inotify._inotify
+    shared_state["crontab_inotify"] = crontab_monitor_emitter._inotify._inotify
 
-    logger.info('(%s) trigger crontab and atjobs refresh' % (pid, ))
+    logger.info("(%s) trigger crontab and atjobs refresh" % (pid,))
 
     # Fake touch event on all crontab files to load initial crontabs
 
@@ -398,7 +449,7 @@ def monitor(configuration):
 
     all_crontab_files, all_atjobs_files = [], []
 
-    for (root, _, files) in walk(crontab_monitor_home):
+    for root, _, files in walk(crontab_monitor_home):
         if crontab_name in files:
             crontab_path = os.path.join(root, crontab_name)
             all_crontab_files.append(crontab_path)
@@ -408,11 +459,11 @@ def monitor(configuration):
 
     for target_path in all_crontab_files + all_atjobs_files:
 
-        logger.debug('(%s) trigger load on cron/at file in %s' %
-                     (pid, target_path))
+        logger.debug(
+            "(%s) trigger load on cron/at file in %s" % (pid, target_path)
+        )
 
-        shared_state['crontab_handler'].dispatch(
-            FileModifiedEvent(target_path))
+        shared_state["crontab_handler"].dispatch(FileModifiedEvent(target_path))
 
     # logger.debug('(%s) loaded initial crontabs:\n%s' % (pid,
     #              all_crontab_files))
@@ -421,81 +472,88 @@ def monitor(configuration):
         try:
             loop_start = datetime.datetime.now()
             loop_minute = loop_start.replace(second=0, microsecond=0)
-            logger.debug('main loop started with %d crontabs and %d atjobs' %
-                         (len(all_crontabs), len(all_atjobs)))
-            for (crontab_path, user_crontab) in all_crontabs.items():
+            logger.debug(
+                "main loop started with %d crontabs and %d atjobs"
+                % (len(all_crontabs), len(all_atjobs))
+            )
+            for crontab_path, user_crontab in all_crontabs.items():
                 client_dir = os.path.basename(os.path.dirname(crontab_path))
                 client_id = client_dir_id(client_dir)
                 for entry in user_crontab:
-                    logger.debug('inspect cron entry for %s: %s' %
-                                 (client_id, entry))
+                    logger.debug(
+                        "inspect cron entry for %s: %s" % (client_id, entry)
+                    )
                     if cron_match(configuration, loop_minute, entry):
-                        logger.info('run matching cron entry: %s' % entry)
-                        run_handler(configuration, client_id, loop_minute,
-                                    entry)
+                        logger.info("run matching cron entry: %s" % entry)
+                        run_handler(
+                            configuration, client_id, loop_minute, entry
+                        )
 
             # NOTE: we need a copy of all_atjobs to avoid errors on inline edit
-            for (atjobs_path, user_atjobs) in list(all_atjobs.items()):
+            for atjobs_path, user_atjobs in list(all_atjobs.items()):
                 client_dir = os.path.basename(os.path.dirname(atjobs_path))
                 client_id = client_dir_id(client_dir)
                 remaining = []
                 for entry in user_atjobs:
-                    logger.debug('inspect atjobs entry for %s: %s' %
-                                 (client_id, entry))
+                    logger.debug(
+                        "inspect atjobs entry for %s: %s" % (client_id, entry)
+                    )
                     remain_mins = at_remain(configuration, loop_minute, entry)
                     if remain_mins == 0:
-                        logger.info('run matching at entry: %s' % entry)
-                        run_handler(configuration, client_id, loop_minute,
-                                    entry)
+                        logger.info("run matching at entry: %s" % entry)
+                        run_handler(
+                            configuration, client_id, loop_minute, entry
+                        )
                     elif remain_mins > 0:
                         remaining.append(entry)
                     else:
-                        logger.info('removing expired at job: %s' % entry)
+                        logger.info("removing expired at job: %s" % entry)
 
                 # Update remaining jobs to clean up expired
                 if remaining:
                     all_atjobs[atjobs_path] = remaining
                 else:
                     del all_atjobs[atjobs_path]
         except KeyboardInterrupt:
-            print('(%s) caught interrupt' % pid)
+            print("(%s) caught interrupt" % pid)
             stop_running()
         except Exception as exc:
-            logger.error('unexpected exception in monitor: %s' % exc)
+            logger.error("unexpected exception in monitor: %s" % exc)
             import traceback
+
             print(traceback.format_exc())
 
         # Throttle down until next minute
         loop_time = (datetime.datetime.now() - loop_start).seconds
         if loop_time > 60:
-            logger.warning('(%s) loop did not finish before next tick: %s' %
-                           (os.getpid(), loop_time))
+            logger.warning(
+                "(%s) loop did not finish before next tick: %s"
+                % (os.getpid(), loop_time)
+            )
             loop_time = 59
         # Target sleep until start of next minute
         sleep_time = max(60 - (loop_time + loop_start.second), 1)
         # TODO: this debug log never shows up - conflict with user info log?
         #       at least it does if changed to info.
-        logger.debug('main loop sleeping %ds' % sleep_time)
+        logger.debug("main loop sleeping %ds" % sleep_time)
         # print('main loop sleeping %ds' % sleep_time)
         time.sleep(sleep_time)
 
-    print('(%s) Exiting crontab monitor' % pid)
-    logger.info('(%s) Exiting crontab monitor' % pid)
+    print("(%s) Exiting crontab monitor" % pid)
+    logger.info("(%s) Exiting crontab monitor" % pid)
     return 0
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Force no log init since we use separate logger
     configuration = get_configuration_object(skip_log=True)
 
     log_level = configuration.loglevel
-    if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning',
-                                        'error']:
+    if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]:
         log_level = sys.argv[1]
 
     # Use separate logger
-    logger = daemon_logger('cron', configuration.user_cron_log,
-                           log_level)
+    logger = daemon_logger("cron", configuration.user_cron_log, log_level)
     configuration.logger = logger
 
     # Allow e.g. logrotate to force log re-open after rotates
@@ -510,24 +568,25 @@ def monitor(configuration):
         print(err_msg)
         sys.exit(1)
 
-    print('''This is the MiG cron handler daemon which monitors user crontab
+    print("""This is the MiG cron handler daemon which monitors user crontab
 files and reacts to any configured actions when time is up.
 
 Set the MIG_CONF environment to the server configuration path
 unless it is available in mig/server/MiGserver.conf
-''')
+""")
 
     main_pid = os.getpid()
-    print('Starting Cron handler daemon - Ctrl-C to quit')
-    logger.info('(%s) Starting Cron handler daemon' % main_pid)
+    print("Starting Cron handler daemon - Ctrl-C to quit")
+    logger.info("(%s) Starting Cron handler daemon" % main_pid)
 
     # Start a single global monitor for all crontabs
-    crontab_monitor = multiprocessing.Process(target=monitor,
-                                              args=(configuration, ))
+    crontab_monitor = multiprocessing.Process(
+        target=monitor, args=(configuration,)
+    )
     crontab_monitor.start()
 
-    logger.debug('(%s) Starting main loop' % main_pid)
+    logger.debug("(%s) Starting main loop" % main_pid)
     print("%s: Start main loop" % os.getpid())
     while not check_stop():
         try:
@@ -537,27 +596,29 @@ def monitor(configuration):
             # NOTE: we can't be sure if SIGINT was sent to only main process
             #       so we make sure to propagate to monitor child
             print("Interrupt requested - close monitor and shutdown")
-            logger.info('(%s) Shut down monitor and wait' % os.getpid())
+            logger.info("(%s) Shut down monitor and wait" % os.getpid())
             mon_pid = crontab_monitor.pid
             if mon_pid is not None:
-                logger.debug('send exit signal to monitor %s' % mon_pid)
+                logger.debug("send exit signal to monitor %s" % mon_pid)
                 os.kill(mon_pid, signal.SIGINT)
             break
         except Exception as exc:
-            logger.error('(%s) Caught unexpected exception: %s' % (os.getpid(),
-                         exc))
+            logger.error(
+                "(%s) Caught unexpected exception: %s" % (os.getpid(), exc)
+            )
 
     mon_pid = crontab_monitor.pid
-    logger.info('Wait for crontab monitors to clean up')
+    logger.info("Wait for crontab monitors to clean up")
    crontab_monitor.join(5)
     if crontab_monitor.is_alive():
-        logger.warning("force kill %s: %s" % (mon_pid,
-                       crontab_monitor.is_alive()))
+        logger.warning(
+            "force kill %s: %s" % (mon_pid, crontab_monitor.is_alive())
+        )
         crontab_monitor.terminate()
     else:
-        logger.debug('crontab monitor %s: done' % mon_pid)
+        logger.debug("crontab monitor %s: done" % mon_pid)
 
-    print('Cron handler daemon shutting down')
-    logger.info('(%s) Cron handler daemon shutting down' % main_pid)
+    print("Cron handler daemon shutting down")
+    logger.info("(%s) Cron handler daemon shutting down" % main_pid)
     sys.exit(0)
diff --git a/mig/server/grid_events.py b/mig/server/grid_events.py
index a0932ca2a..c57f4057c 100755
--- a/mig/server/grid_events.py
+++ b/mig/server/grid_events.py
@@ -32,8 +32,7 @@
 Requires watchdog module (https://pypi.python.org/pypi/watchdog).
""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import datetime import fnmatch @@ -48,36 +47,60 @@ import signal import sys import tempfile -import time import threading +import time try: + from watchdog.events import ( + DirCreatedEvent, + DirDeletedEvent, + DirModifiedEvent, + FileCreatedEvent, + FileDeletedEvent, + FileModifiedEvent, + PatternMatchingEventHandler, + ) from watchdog.observers import Observer - from watchdog.events import PatternMatchingEventHandler, \ - FileModifiedEvent, FileCreatedEvent, FileDeletedEvent, \ - DirModifiedEvent, DirCreatedEvent, DirDeletedEvent except ImportError: - print('ERROR: the python watchdog module is required for this daemon') + print("ERROR: the python watchdog module is required for this daemon") sys.exit(1) from mig.lib.daemon import check_stop, register_stop_handler, stop_running -from mig.lib.events import CACHE_EXPIRE_SIZE, DEFAULT_PERIOD, DEFAULT_TIME, \ - MISS_CACHE_TTL, RATE_LIMIT_FIELD, SETTLE_TIME_FIELD, TRIGGER_EVENT, \ - UNIT_PERIODS, get_path_expand_map, run_events_command +from mig.lib.events import ( + CACHE_EXPIRE_SIZE, + DEFAULT_PERIOD, + DEFAULT_TIME, + MISS_CACHE_TTL, + RATE_LIMIT_FIELD, + SETTLE_TIME_FIELD, + TRIGGER_EVENT, + UNIT_PERIODS, + get_path_expand_map, + run_events_command, +) from mig.shared.base import force_utf8 from mig.shared.conf import get_configuration_object -from mig.shared.defaults import valid_trigger_changes, workflows_log_name, \ - workflows_log_size, workflows_log_cnt, default_vgrid -from mig.shared.fileio import makedirs_rec, pickle, unpickle, scandir, walk +from mig.shared.defaults import ( + default_vgrid, + valid_trigger_changes, + workflows_log_cnt, + workflows_log_name, + workflows_log_size, +) +from mig.shared.fileio import makedirs_rec, pickle, scandir, unpickle, walk from mig.shared.job import fill_mrsl_template, new_job from mig.shared.listhandling import frange from mig.shared.logger import daemon_logger, register_hangup_handler -from mig.shared.safeinput import PARAM_START, PARAM_STOP, PARAM_JUMP +from mig.shared.safeinput import PARAM_JUMP, PARAM_START, PARAM_STOP from mig.shared.serial import load -from mig.shared.vgrid import vgrid_valid_entities, vgrid_add_workflow_jobs, \ - JOB_ID, JOB_CLIENT +from mig.shared.vgrid import ( + JOB_CLIENT, + JOB_ID, + vgrid_add_workflow_jobs, + vgrid_valid_entities, +) from mig.shared.vgridaccess import check_vgrid_access -from mig.shared.workflows import get_wp_map, CONF +from mig.shared.workflows import CONF, get_wp_map # Global trigger rule dictionaries with rules for all VGrids @@ -92,18 +115,18 @@ # Global state helpers used in a number of functions and methods shared_state = {} -shared_state['base_dir'] = None -shared_state['base_dir_len'] = 0 -shared_state['writable_dir'] = None -shared_state['writable_dir_len'] = 0 -shared_state['file_inotify'] = None -shared_state['file_handler'] = None -shared_state['rule_handler'] = None -shared_state['rule_inotify'] = None +shared_state["base_dir"] = None +shared_state["base_dir_len"] = 0 +shared_state["writable_dir"] = None +shared_state["writable_dir_len"] = 0 +shared_state["file_inotify"] = None +shared_state["file_handler"] = None +shared_state["rule_handler"] = None +shared_state["rule_inotify"] = None _hits_lock = threading.Lock() _rule_monitor_lock = threading.Lock() -(configuration, logger) = (None, None) +configuration, logger = (None, None) def make_fake_event(path, state, is_directory=False): @@ -111,11 
+134,16 @@ def make_fake_event(path, state, is_directory=False): change is a directory or file. """ - file_map = {'modified': FileModifiedEvent, - 'created': FileCreatedEvent, - 'deleted': FileDeletedEvent} - dir_map = {'modified': DirModifiedEvent, - 'created': DirCreatedEvent, 'deleted': DirDeletedEvent} + file_map = { + "modified": FileModifiedEvent, + "created": FileCreatedEvent, + "deleted": FileDeletedEvent, + } + dir_map = { + "modified": DirModifiedEvent, + "created": DirCreatedEvent, + "deleted": DirDeletedEvent, + } if is_directory or os.path.isdir(path): fake = dir_map[state](path) else: @@ -143,7 +171,7 @@ def extract_time_in_secs(rule, field): pid = multiprocessing.current_process().pid - limit_str = rule.get(field, '') + limit_str = rule.get(field, "") if not limit_str: limit_str = "%s" % DEFAULT_TIME @@ -160,14 +188,15 @@ def extract_time_in_secs(rule, field): # print "ERROR: invalid time value %s ... fall back to defaults" % \ # limit_str - (unit_key, val_str) = (DEFAULT_PERIOD, DEFAULT_TIME) + unit_key, val_str = (DEFAULT_PERIOD, DEFAULT_TIME) else: val_str = limit_str try: secs = float(val_str) * UNIT_PERIODS[unit_key] except Exception as exc: - print('(%s) ERROR: failed to parse time %s (%s)!' % (pid, - limit_str, exc)) + print( + "(%s) ERROR: failed to parse time %s (%s)!" % (pid, limit_str, exc) + ) secs = 0.0 secs = max(secs, 0.0) return secs @@ -179,15 +208,15 @@ def extract_hit_limit(rule, field): within the last period_length seconds. """ - limit_str = rule.get(field, '') + limit_str = rule.get(field, "") # NOTE: format is 3(/m) or 52/h # split string on slash and fall back to no limit and default unit - parts = (limit_str.split('/', 1) + [DEFAULT_PERIOD])[:2] - (number, unit) = parts + parts = (limit_str.split("/", 1) + [DEFAULT_PERIOD])[:2] + number, unit = parts if not number.isdigit(): - number = '-1' + number = "-1" if unit not in UNIT_PERIODS: unit = DEFAULT_PERIOD return (int(number), UNIT_PERIODS[unit]) @@ -206,7 +235,7 @@ def update_rule_hits( """ pid = multiprocessing.current_process().pid - (_, hit_period) = extract_hit_limit(rule, RATE_LIMIT_FIELD) + _, hit_period = extract_hit_limit(rule, RATE_LIMIT_FIELD) settle_period = extract_time_in_secs(rule, SETTLE_TIME_FIELD) # logger.debug('(%s) update rule hits at %s for %s and %s %s %s' % ( @@ -219,12 +248,13 @@ def update_rule_hits( # )) _hits_lock.acquire() - rule_history = rule_hits.get(rule['rule_id'], []) + rule_history = rule_hits.get(rule["rule_id"], []) rule_history.append((path, change, ref, time_stamp)) max_period = max(hit_period, settle_period) - period_history = [i for i in rule_history if time_stamp - i[3] - <= max_period] - rule_hits[rule['rule_id']] = period_history + period_history = [ + i for i in rule_history if time_stamp - i[3] <= max_period + ] + rule_hits[rule["rule_id"]] = period_history _hits_lock.release() # logger.debug('(%s) updated rule hits for %s to %s' % (pid, @@ -237,16 +267,17 @@ def get_rule_hits(rule, limit_field): pid = multiprocessing.current_process().pid if limit_field == RATE_LIMIT_FIELD: - (hit_count, hit_period) = extract_hit_limit(rule, limit_field) + hit_count, hit_period = extract_hit_limit(rule, limit_field) elif limit_field == SETTLE_TIME_FIELD: - (hit_count, hit_period) = (1, extract_time_in_secs(rule, limit_field)) + hit_count, hit_period = (1, extract_time_in_secs(rule, limit_field)) else: - logger.error('(%s) get_rule_hits invalid limit_field %s' % - (pid, limit_field)) + logger.error( + "(%s) get_rule_hits invalid limit_field %s" % (pid, limit_field) + ) 
         raise ValueError("got unexpected limit_field %r" % limit_field)
     _hits_lock.acquire()
-    rule_history = rule_hits.get(rule['rule_id'], [])
+    rule_history = rule_hits.get(rule["rule_id"], [])
     res = (rule_history, hit_count, hit_period)
     _hits_lock.release()
@@ -258,8 +289,7 @@ def get_rule_hits(rule, limit_field):
 
 def get_path_hits(rule, path, limit_field):
     """find path hit details"""
 
-    (rule_history, hit_count, hit_period) = get_rule_hits(rule,
-                                                          limit_field)
+    rule_history, hit_count, hit_period = get_rule_hits(rule, limit_field)
     path_history = [i for i in rule_history if i[0] == path]
     return (path_history, hit_count, hit_period)
@@ -276,15 +306,15 @@ def above_path_limit(
 
     pid = multiprocessing.current_process().pid
 
-    (path_history, hit_count, hit_period) = get_path_hits(rule, path,
-                                                          limit_field)
+    path_history, hit_count, hit_period = get_path_hits(rule, path, limit_field)
     if hit_count <= 0 or hit_period <= 0:
 
         # logger.debug('(%s) no %s limit set' % (pid, limit_field))
 
         return False
-    period_history = [i for i in path_history if time_stamp - i[3]
-                      <= hit_period]
+    period_history = [
+        i for i in path_history if time_stamp - i[3] <= hit_period
+    ]
 
     # logger.debug('(%s) above path %s test found %s vs %d' % (pid,
     #              limit_field, period_history, hit_count))
@@ -299,12 +329,12 @@ def show_path_hits(rule, path, limit_field):
 
     pid = multiprocessing.current_process().pid
 
-    msg = ''
-    (path_history, hit_count, hit_period) = get_path_hits(rule, path,
-                                                          limit_field)
-    msg += \
-        '(%s) found %d entries in trigger history and limit is %d per %s s' \
+    msg = ""
+    path_history, hit_count, hit_period = get_path_hits(rule, path, limit_field)
+    msg += (
+        "(%s) found %d entries in trigger history and limit is %d per %s s"
         % (pid, len(path_history), hit_count, hit_period)
+    )
     return msg
@@ -323,10 +353,10 @@ def wait_settled(
 
     pid = multiprocessing.current_process().pid
     limit_field = SETTLE_TIME_FIELD
-    (path_history, _, hit_period) = get_path_hits(rule, path,
-                                                  limit_field)
-    period_history = [i for i in path_history if time_stamp - i[3]
-                      <= hit_period]
+    path_history, _, hit_period = get_path_hits(rule, path, limit_field)
+    period_history = [
+        i for i in path_history if time_stamp - i[3] <= hit_period
+    ]
 
     # logger.debug('(%s) wait_settled: path %s, change %s, settle_secs %s'
     #              % (pid, path, change, settle_secs))
@@ -340,8 +370,7 @@ def wait_settled(
     # Thus we can just take the smallest and subtract from settle_secs
     # to always wait the remaining part of settle_secs.
-    remain = settle_secs - min([time_stamp - i[3] for i in
-                               period_history])
+    remain = settle_secs - min([time_stamp - i[3] for i in period_history])
 
     # logger.debug('(%s) wait_settled: remain %.1f , period_history %s'
     #              % (pid, remain, period_history))
@@ -359,8 +388,10 @@ def recently_modified(path, time_stamp, slack=2.0):
 
     try:
         stat_res = os.stat(path)
-        result = stat_res.st_mtime == stat_res.st_atime \
+        result = (
+            stat_res.st_mtime == stat_res.st_atime
             or stat_res.st_mtime > time_stamp - slack
+        )
     except OSError as exc:
 
         # If we get an OSError, *path* is most likely deleted
@@ -374,15 +405,14 @@ def recently_modified(path, time_stamp, slack=2.0):
 def strip_base_dirs(path):
     """strips base directories from a given path"""
 
-    if shared_state['base_dir'] in path:
-        return path[shared_state['base_dir_len']:]
-    if shared_state['writable_dir'] in path:
-        return path[shared_state['writable_dir_len']:]
+    if shared_state["base_dir"] in path:
+        return path[shared_state["base_dir_len"] :]
+    if shared_state["writable_dir"] in path:
+        return path[shared_state["writable_dir_len"] :]
     return path
 
 
 class MiGRuleEventHandler(PatternMatchingEventHandler):
-
     """Rule pattern-matching event handler to take care of VGrid rule
     changes and update the global rule database.
     """
@@ -397,9 +427,12 @@ def __init__(
         """Constructor"""
 
         PatternMatchingEventHandler.__init__(
-            self, patterns=patterns, ignore_patterns=ignore_patterns,
+            self,
+            patterns=patterns,
+            ignore_patterns=ignore_patterns,
             ignore_directories=ignore_directories,
-            case_sensitive=case_sensitive)
+            case_sensitive=case_sensitive,
+        )
 
     def __update_rule_monitor(
         self,
@@ -410,25 +443,26 @@ def __update_rule_monitor(
 
         pid = multiprocessing.current_process().pid
 
-        if state == 'created':
+        if state == "created":
             # logger.debug('(%s) Updating rule monitor for src_path: %s, event: %s'
             #              % (pid, src_path, state))
-            print('(%s) Updating rule monitor for src_path: %s, event: %s'
-                  % (pid, src_path, state))
+            print(
+                "(%s) Updating rule monitor for src_path: %s, event: %s"
+                % (pid, src_path, state)
+            )
             if os.path.exists(src_path):
                 # _rule_monitor_lock.acquire()
-                if src_path not in shared_state['rule_inotify']._wd_for_path:
+                if src_path not in shared_state["rule_inotify"]._wd_for_path:
                     # logger.debug('(%s) Adding watch for: %s' % (pid,
                     #                                             src_path))
-                    shared_state['rule_inotify'].add_watch(
-                        force_utf8(src_path))
+                    shared_state["rule_inotify"].add_watch(force_utf8(src_path))
 
                     # Fire 'modified' events for all dirs and files in subpath
                     # to ensure that all rule files are loaded
@@ -439,16 +473,17 @@ def __update_rule_monitor(
                             # logger.debug('(%s) Dispatch DirCreatedEvent for: %s'
                             #              % (pid, ent.path))
-                            shared_state['rule_handler'].dispatch(
-                                DirCreatedEvent(ent.path))
-                        elif ent.path.find(configuration.vgrid_triggers) \
-                                > -1:
+                            shared_state["rule_handler"].dispatch(
+                                DirCreatedEvent(ent.path)
+                            )
+                        elif ent.path.find(configuration.vgrid_triggers) > -1:
                             # logger.debug('(%s) Dispatch FileCreatedEvent for: %s'
                             #              % (pid, ent.path))
-                            shared_state['rule_handler'].dispatch(
-                                FileCreatedEvent(ent.path))
+                            shared_state["rule_handler"].dispatch(
+                                FileCreatedEvent(ent.path)
+                            )
 
                 # else:
                 #     logger.debug('(%s) rule_monitor watch already exists for: %s'
@@ -471,24 +506,28 @@ def update_rules(self, event):
 
         # logger.debug('(%s) %s -> Updating rule for: %s' % (pid,
         #              state, src_path))
-        rel_path = src_path[len(configuration.vgrid_home):]
-        vgrid_name = rel_path[:-len(configuration.vgrid_triggers)
-                              - 1]
-        vgrid_prefix = os.path.join(configuration.vgrid_files_home,
-                                    vgrid_name, '')
-        logger.info('(%s) refresh %s rules from %s' %
-                    (pid, vgrid_name, src_path))
+        rel_path = src_path[len(configuration.vgrid_home) :]
+        vgrid_name = rel_path[: -len(configuration.vgrid_triggers) - 1]
+        vgrid_prefix = os.path.join(
+            configuration.vgrid_files_home, vgrid_name, ""
+        )
+        logger.info(
+            "(%s) refresh %s rules from %s" % (pid, vgrid_name, src_path)
+        )
         try:
             raw_rules = load(src_path)
             # NOTE: manually filter out any broken rules once and for all
             #       this is like if loaded with vgrid_triggers()
-            new_rules = vgrid_valid_entities(configuration, vgrid_name,
-                                             'triggers', raw_rules)
+            new_rules = vgrid_valid_entities(
+                configuration, vgrid_name, "triggers", raw_rules
+            )
         except Exception as exc:
             new_rules = []
-            if state != 'deleted':
-                logger.error('(%s) failed to load event handler rules from %s (%s)'
-                             % (pid, src_path, exc))
+            if state != "deleted":
+                logger.error(
+                    "(%s) failed to load event handler rules from %s (%s)"
+                    % (pid, src_path, exc)
+                )
 
         # logger.debug("(%s) loaded new rules from '%s':\n%s" % (pid,
         #              src_path, new_rules))
@@ -498,11 +537,16 @@ def update_rules(self, event):
 
         # NOTE: we need to iterate over a copy of keys for in-place edits
         for target_path in list(all_rules):
-            all_rules[target_path] = [i for i in
-                                      all_rules[target_path] if i['vgrid_name']
-                                      != vgrid_name]
-            remain_rules = [i for i in all_rules[target_path]
-                            if i['vgrid_name'] != vgrid_name]
+            all_rules[target_path] = [
+                i
+                for i in all_rules[target_path]
+                if i["vgrid_name"] != vgrid_name
+            ]
+            remain_rules = [
+                i
+                for i in all_rules[target_path]
+                if i["vgrid_name"] != vgrid_name
+            ]
             if remain_rules:
                 all_rules[target_path] = remain_rules
             else:
@@ -514,13 +558,14 @@ def update_rules(self, event):
                 del all_rules[target_path]
 
         for entry in new_rules:
-            rule_id = entry['rule_id']
-            path = entry['path']
-            logger.info('(%s) updating rule: %s, path: %s, entry:\n%s'
-                        % (pid, rule_id, path, entry))
+            rule_id = entry["rule_id"]
+            path = entry["path"]
+            logger.info(
+                "(%s) updating rule: %s, path: %s, entry:\n%s"
+                % (pid, rule_id, path, entry)
+            )
             abs_path = os.path.join(vgrid_prefix, path)
-            all_rules[abs_path] = all_rules.get(abs_path, []) \
-                + [entry]
+            all_rules[abs_path] = all_rules.get(abs_path, []) + [entry]
 
         # logger.debug('(%s) all rules:\n%s' % (pid, all_rules))
         # else:
@@ -544,7 +589,6 @@ def on_deleted(self, event):
 
 
 class MiGFileEventHandler(PatternMatchingEventHandler):
-
     """File pattern-matching event handler to take care of VGrid file
     changes and the corresponding action triggers.
""" @@ -555,14 +599,17 @@ def __init__( ignore_patterns=None, ignore_directories=False, case_sensitive=False, - sub_vgrids=None + sub_vgrids=None, ): """Constructor""" PatternMatchingEventHandler.__init__( - self, patterns=patterns, ignore_patterns=ignore_patterns, + self, + patterns=patterns, + ignore_patterns=ignore_patterns, ignore_directories=ignore_directories, - case_sensitive=case_sensitive) + case_sensitive=case_sensitive, + ) self.sub_vgrids = sub_vgrids def __workflow_log( @@ -570,26 +617,25 @@ def __workflow_log( configuration, vgrid_name, msg, - level='info', + level="info", ): """Wrapper to send a single msg to vgrid workflows page log file""" - log_name = '%s.%s' % (configuration.vgrid_triggers, - workflows_log_name) - log_path = os.path.join(configuration.vgrid_home, vgrid_name, - log_name) - workflows_logger = logging.getLogger('workflows') + log_name = "%s.%s" % (configuration.vgrid_triggers, workflows_log_name) + log_path = os.path.join(configuration.vgrid_home, vgrid_name, log_name) + workflows_logger = logging.getLogger("workflows") workflows_logger.setLevel(logging.INFO) handler = logging.handlers.RotatingFileHandler( - log_path, maxBytes=workflows_log_size, - backupCount=workflows_log_cnt - 1) - formatter = \ - logging.Formatter('%(asctime)s %(levelname)s %(message)s') + log_path, + maxBytes=workflows_log_size, + backupCount=workflows_log_cnt - 1, + ) + formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") handler.setFormatter(formatter) workflows_logger.addHandler(handler) - if level == 'error': + if level == "error": workflows_logger.error(msg) - elif level == 'warning': + elif level == "warning": workflows_logger.warning(msg) else: workflows_logger.info(msg) @@ -605,7 +651,7 @@ def __workflow_err( ): """Wrapper to send a single error msg to vgrid workflows page log""" - self.__workflow_log(configuration, vgrid_name, msg, 'error') + self.__workflow_log(configuration, vgrid_name, msg, "error") def __workflow_warn( self, @@ -615,7 +661,7 @@ def __workflow_warn( ): """Wrapper to send a single warning msg to vgrid workflows page log""" - self.__workflow_log(configuration, vgrid_name, msg, 'warning') + self.__workflow_log(configuration, vgrid_name, msg, "warning") def __workflow_info( self, @@ -625,7 +671,7 @@ def __workflow_info( ): """Wrapper to send a single error msg to vgrid workflows page log""" - self.__workflow_log(configuration, vgrid_name, msg, 'info') + self.__workflow_log(configuration, vgrid_name, msg, "info") def __add_trigger_job_ent( self, @@ -638,42 +684,48 @@ def __add_trigger_job_ent( result = True pid = multiprocessing.current_process().pid - vgrid_name = rule['vgrid_name'] + vgrid_name = rule["vgrid_name"] trigger_job_dir = os.path.join( - configuration.vgrid_home, os.path.join(vgrid_name, os.path.join( - '.%s.jobs' % configuration.vgrid_triggers, 'pending_states'))) + configuration.vgrid_home, + os.path.join( + vgrid_name, + os.path.join( + ".%s.jobs" % configuration.vgrid_triggers, "pending_states" + ), + ), + ) trigger_job_filepath = os.path.join(trigger_job_dir, jobid) if makedirs_rec(trigger_job_dir, configuration): trigger_job_dict = { - 'jobid': jobid, - 'owner': rule['run_as'], - 'rule': rule, - 'event': {}, + "jobid": jobid, + "owner": rule["run_as"], + "rule": rule, + "event": {}, } - src_path = '' - if hasattr(event, 'src_path'): + src_path = "" + if hasattr(event, "src_path"): src_path = event.src_path - dest_path = '' - if hasattr(event, 'dest_path'): + dest_path = "" + if hasattr(event, "dest_path"): dest_path = 
event.dest_path - trigger_job_dict['event']['src_path'] = src_path - trigger_job_dict['event']['dest_path'] = dest_path - trigger_job_dict['event']['time_stamp'] = event.time_stamp - trigger_job_dict['event']['event_type'] = event.event_type - trigger_job_dict['event']['is_directory'] = \ - event.is_directory + trigger_job_dict["event"]["src_path"] = src_path + trigger_job_dict["event"]["dest_path"] = dest_path + trigger_job_dict["event"]["time_stamp"] = event.time_stamp + trigger_job_dict["event"]["event_type"] = event.event_type + trigger_job_dict["event"]["is_directory"] = event.is_directory # logger.debug('(%s) trigger_job_dict: %s' % (pid, # trigger_job_dict)) - if not pickle(trigger_job_dict, trigger_job_filepath, - logger): + if not pickle(trigger_job_dict, trigger_job_filepath, logger): result = False else: - logger.error('(%s) Failed to create trigger job dir: %s' - % (pid, trigger_job_dir)) + logger.error( + "(%s) Failed to create trigger job dir: %s" + % (pid, trigger_job_dir) + ) result = False return result @@ -692,27 +744,45 @@ def __handle_trigger( state = event.event_type src_path = event.src_path time_stamp = event.time_stamp - _chain = getattr(event, '_chain', [(src_path, state)]) + _chain = getattr(event, "_chain", [(src_path, state)]) rel_src = strip_base_dirs(src_path).lstrip(os.sep) vgrid_prefix = os.path.join( - shared_state['base_dir'], rule['vgrid_name']) - logger.info('(%s) in handling of %s for %s %s' % - (pid, rule['action'], state, rel_src)) + shared_state["base_dir"], rule["vgrid_name"] + ) + logger.info( + "(%s) in handling of %s for %s %s" + % (pid, rule["action"], state, rel_src) + ) above_limit = False # Run settle time check first to only trigger rate limit if settled - for (name, field) in [('settle time', SETTLE_TIME_FIELD), - ('rate limit', RATE_LIMIT_FIELD)]: + for name, field in [ + ("settle time", SETTLE_TIME_FIELD), + ("rate limit", RATE_LIMIT_FIELD), + ]: if above_path_limit(rule, src_path, field, time_stamp): above_limit = True - logger.warning('(%s) skip %s due to %s: %s' % - (pid, src_path, name, show_path_hits( - rule, src_path, field))) - self.__workflow_warn(configuration, rule['vgrid_name'], - '(%s) skip %s trigger due to %s: %s' % - (pid, rel_src, name, show_path_hits( - rule, src_path, field))) + logger.warning( + "(%s) skip %s due to %s: %s" + % ( + pid, + src_path, + name, + show_path_hits(rule, src_path, field), + ) + ) + self.__workflow_warn( + configuration, + rule["vgrid_name"], + "(%s) skip %s trigger due to %s: %s" + % ( + pid, + rel_src, + name, + show_path_hits(rule, src_path, field), + ), + ) break # TODO: consider if we should skip modified when just created @@ -720,25 +790,35 @@ def __handle_trigger( # We receive modified events even when only atime changed - ignore them # but make sure we handle our fake trigger-modified events - if state == 'modified' and not is_fake_event(event) \ - and not recently_modified(src_path, time_stamp): - logger.info('(%s) skip %s which only changed atime' % (pid, - src_path)) - self.__workflow_info(configuration, rule['vgrid_name'], - 'skip %s modified access time only event' - % rel_src) + if ( + state == "modified" + and not is_fake_event(event) + and not recently_modified(src_path, time_stamp) + ): + logger.info( + "(%s) skip %s which only changed atime" % (pid, src_path) + ) + self.__workflow_info( + configuration, + rule["vgrid_name"], + "skip %s modified access time only event" % rel_src, + ) return # Always update here to get trigger hits even for limited events - update_rule_hits(rule, 
src_path, state, '', time_stamp) + update_rule_hits(rule, src_path, state, "", time_stamp) if above_limit: return - logger.info('(%s) proceed with handling of %s for %s %s' - % (pid, rule['action'], state, rel_src)) - self.__workflow_info(configuration, rule['vgrid_name'], - 'handle %s for %s %s' % (rule['action'], - state, rel_src)) + logger.info( + "(%s) proceed with handling of %s for %s %s" + % (pid, rule["action"], state, rel_src) + ) + self.__workflow_info( + configuration, + rule["vgrid_name"], + "handle %s for %s %s" % (rule["action"], state, rel_src), + ) settle_secs = extract_time_in_secs(rule, SETTLE_TIME_FIELD) if settle_secs > 0.0: wait_secs = settle_secs @@ -749,41 +829,47 @@ def __handle_trigger( # target_path, rule)) while wait_secs > 0.0: - logger.info('(%s) wait %.1fs for %s file events to settle down' - % (pid, wait_secs, src_path)) - self.__workflow_info(configuration, rule['vgrid_name'], - 'wait %.1fs for events on %s to settle' - % (wait_secs, rel_src)) + logger.info( + "(%s) wait %.1fs for %s file events to settle down" + % (pid, wait_secs, src_path) + ) + self.__workflow_info( + configuration, + rule["vgrid_name"], + "wait %.1fs for events on %s to settle" % (wait_secs, rel_src), + ) time.sleep(wait_secs) # logger.debug('(%s) slept %.1fs for %s file events to settle down' # % (pid, wait_secs, src_path)) time_stamp += wait_secs - wait_secs = wait_settled(rule, src_path, state, - settle_secs, time_stamp) + wait_secs = wait_settled( + rule, src_path, state, settle_secs, time_stamp + ) # TODO: perhaps we should discriminate on files and dirs here? # TODO: logger does not actually work here, only __workflow_X logs - if rule['action'] in ['trigger-%s' % i for i in - valid_trigger_changes]: - change = rule['action'].replace('trigger-', '') + if rule["action"] in ["trigger-%s" % i for i in valid_trigger_changes]: + change = rule["action"].replace("trigger-", "") # Expand dynamic variables in argument once and for all expand_map = get_path_expand_map(rel_src, rule, state) - for argument in rule['arguments']: + for argument in rule["arguments"]: filled_argument = argument - for (key, val) in expand_map.items(): + for key, val in expand_map.items(): filled_argument = filled_argument.replace(key, val) # logger.debug('(%s) expanded argument %s to %s' % (pid, # argument, filled_argument)) - self.__workflow_info(configuration, rule['vgrid_name'], - 'expanded argument %s to %s' % - (argument, filled_argument)) + self.__workflow_info( + configuration, + rule["vgrid_name"], + "expanded argument %s to %s" % (argument, filled_argument), + ) pattern = os.path.join(vgrid_prefix, filled_argument) for path in glob.glob(pattern): rel_path = strip_base_dirs(path) @@ -792,71 +878,81 @@ def __handle_trigger( # Prevent obvious trigger chain cycles if (path, change) in _chain[:-1]: - flat_chain = ['%s : %s' % pair for pair in - _chain] - chain_str = ' <-> '.join(flat_chain) + flat_chain = ["%s : %s" % pair for pair in _chain] + chain_str = " <-> ".join(flat_chain) rel_chain_str = strip_base_dirs(chain_str) - logger.warning('(%s) breaking trigger cycle %s' - % (pid, chain_str)) - self.__workflow_warn(configuration, - rule['vgrid_name'], - 'breaking trigger cycle %s' - % rel_chain_str) + logger.warning( + "(%s) breaking trigger cycle %s" % (pid, chain_str) + ) + self.__workflow_warn( + configuration, + rule["vgrid_name"], + "breaking trigger cycle %s" % rel_chain_str, + ) continue fake = make_fake_event(path, change) fake._chain = _chain - logger.info('(%s) trigger %s event on %s' % (pid, - 
change, path)) + logger.info( + "(%s) trigger %s event on %s" % (pid, change, path) + ) self.__workflow_info( - configuration, rule['vgrid_name'], - 'trigger %s event on %s' % (change, rel_path)) + configuration, + rule["vgrid_name"], + "trigger %s event on %s" % (change, rel_path), + ) self.handle_event(fake) - elif rule['action'] == 'submit': + elif rule["action"] == "submit": temp_dir = tempfile.mkdtemp() # Expand dynamic variables in argument once and for all expand_map = get_path_expand_map(rel_src, rule, state) try: - for job_template in rule['templates']: - pattern_id = rule['pattern_id'] - pattern_map = get_wp_map( - configuration).get(pattern_id, None) + for job_template in rule["templates"]: + pattern_id = rule["pattern_id"] + pattern_map = get_wp_map(configuration).get( + pattern_id, None + ) if not pattern_map: - raise Exception('(%s) pattern entry %s is missing' - % (pid, pattern_id)) + raise Exception( + "(%s) pattern entry %s is missing" + % (pid, pattern_id) + ) pattern = pattern_map[CONF] # logger.debug('DM setting up logging for job with ' # 'pattern: %s' % pattern) - rule_id = rule['rule_id'] + rule_id = rule["rule_id"] - recipe_list = \ - [(recipe['name'], recipe['persistence_id']) - for recipe - in pattern['trigger_recipes'][rule_id].values()] + recipe_list = [ + (recipe["name"], recipe["persistence_id"]) + for recipe in pattern["trigger_recipes"][ + rule_id + ].values() + ] workflow_dict = { - 'trigger_id': rule_id, - 'trigger_path': src_path, - 'trigger_time': - datetime.datetime.fromtimestamp(time_stamp), - 'pattern_id': pattern['persistence_id'], - 'pattern_name': pattern['name'], - 'recipes': recipe_list + "trigger_id": rule_id, + "trigger_path": src_path, + "trigger_time": datetime.datetime.fromtimestamp( + time_stamp + ), + "pattern_id": pattern["persistence_id"], + "pattern_name": pattern["name"], + "recipes": recipe_list, } # logger.debug('DM workflow_dict: %s' % workflow_dict) # In cases where parameterize over is defined in a workflow # Pattern, many different but related jobs will need to be # created. 
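# The branch below expands a 'parameterize_over' sweep into one job per value
# combination. A minimal, self-contained sketch of that expansion, assuming
# numeric start/stop/increment fields: PARAM_START and PARAM_STOP mirror the
# constants used here, while PARAM_INCREMENT and both helper names are
# illustrative assumptions, not the daemon's actual API.
import itertools

PARAM_START, PARAM_STOP, PARAM_INCREMENT = "start", "stop", "increment"


def expand_sweep(sweep):
    """Yield every value from start through stop in increment-sized steps."""
    value = float(sweep[PARAM_START])
    stop = float(sweep[PARAM_STOP])
    step = float(sweep[PARAM_INCREMENT])
    while value <= stop:
        yield value
        value += step


def sweep_job_params(parameterize_over):
    """Return one dict per job: the cartesian product of all swept values."""
    names = list(parameterize_over)
    values = [list(expand_sweep(parameterize_over[name])) for name in names]
    return [dict(zip(names, combo)) for combo in itertools.product(*values)]


# sweep_job_params({"alpha": {"start": 0, "stop": 1, "increment": 0.5},
#                   "beta": {"start": 1, "stop": 2, "increment": 1}})
# -> six parameter dicts (three alpha values x two beta values)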
- if pattern['parameterize_over']: + if pattern["parameterize_over"]: all_values = [] - for (var, sweep) in pattern['parameterize_over'].items(): + for var, sweep in pattern["parameterize_over"].items(): start = float(sweep[PARAM_START]) stop = float(sweep[PARAM_STOP]) @@ -874,109 +970,155 @@ def __handle_trigger( for job_param in job_params: self.__schedule_job( - job_template, rel_src, state, rule, expand_map, - pid, event, target_path, temp_dir, - workflow_job=workflow_dict, param=job_param) + job_template, + rel_src, + state, + rule, + expand_map, + pid, + event, + target_path, + temp_dir, + workflow_job=workflow_dict, + param=job_param, + ) else: self.__schedule_job( - job_template, rel_src, state, rule, expand_map, - pid, event, target_path, temp_dir, - workflow_job=workflow_dict) + job_template, + rel_src, + state, + rule, + expand_map, + pid, + event, + target_path, + temp_dir, + workflow_job=workflow_dict, + ) except Exception as exc: - logger.error('(%s) failed to submit job(s) for %s: %s' - % (pid, target_path, exc)) - self.__workflow_err(configuration, rule['vgrid_name'], - 'failed to submit job for %s: %s' - % (rel_src, exc)) + logger.error( + "(%s) failed to submit job(s) for %s: %s" + % (pid, target_path, exc) + ) + self.__workflow_err( + configuration, + rule["vgrid_name"], + "failed to submit job for %s: %s" % (rel_src, exc), + ) try: shutil.rmtree(temp_dir) except Exception as exc: - logger.warning('(%s) clean up after submit failed: %s' - % (pid, exc)) - elif rule['action'] == 'command': + logger.warning( + "(%s) clean up after submit failed: %s" % (pid, exc) + ) + elif rule["action"] == "command": # Expand dynamic variables in argument once and for all expand_map = get_path_expand_map(rel_src, rule, state) - command_list = (rule['arguments'])[:1] - for argument in (rule['arguments'])[1:]: + command_list = (rule["arguments"])[:1] + for argument in (rule["arguments"])[1:]: filled_argument = argument - for (key, val) in expand_map.items(): + for key, val in expand_map.items(): filled_argument = filled_argument.replace(key, val) - self.__workflow_info(configuration, rule['vgrid_name'], - 'expanded argument %s to %s' % - (argument, filled_argument)) + self.__workflow_info( + configuration, + rule["vgrid_name"], + "expanded argument %s to %s" % (argument, filled_argument), + ) command_list.append(filled_argument) try: - run_events_command(command_list, target_path, rule, - configuration) - logger.info('(%s) done running command for %s: %s' % - (pid, target_path, ' '.join(command_list))) - self.__workflow_info(configuration, rule['vgrid_name'], - 'ran command: %s' - % ' '.join(command_list)) + run_events_command( + command_list, target_path, rule, configuration + ) + logger.info( + "(%s) done running command for %s: %s" + % (pid, target_path, " ".join(command_list)) + ) + self.__workflow_info( + configuration, + rule["vgrid_name"], + "ran command: %s" % " ".join(command_list), + ) except Exception as exc: - command_str = ' '.join(command_list) - logger.error('(%s) failed to run command for %s: %s (%s)' % - (pid, target_path, command_str, exc)) - self.__workflow_err(configuration, rule['vgrid_name'], - 'failed to run command for %s: %s (%s)' % - (rel_src, command_str, exc)) + command_str = " ".join(command_list) + logger.error( + "(%s) failed to run command for %s: %s (%s)" + % (pid, target_path, command_str, exc) + ) + self.__workflow_err( + configuration, + rule["vgrid_name"], + "failed to run command for %s: %s (%s)" + % (rel_src, command_str, exc), + ) else: - 
logger.error('(%s) unsupported action: %s' % (pid, - rule['action'])) + logger.error("(%s) unsupported action: %s" % (pid, rule["action"])) - def __schedule_job(self, job_template, rel_src, state, rule, expand_map, - pid, event, target_path, temp_dir, workflow_job=None, - param=None): + def __schedule_job( + self, + job_template, + rel_src, + state, + rule, + expand_map, + pid, + event, + target_path, + temp_dir, + workflow_job=None, + param=None, + ): """Creates a new job from a triggered event by calling the mRSL file creation, enqueueing the job, and updating the VGrid history. Takes optional parameter 'param', used in workflow Patterns with the - 'parameterize_over' variable. """ + 'parameterize_over' variable.""" mrsl_fd = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) mrsl_path = mrsl_fd.name if not fill_mrsl_template( - job_template, - mrsl_fd, - rel_src, - state, - rule, - expand_map, - configuration, - param_list=param + job_template, + mrsl_fd, + rel_src, + state, + rule, + expand_map, + configuration, + param_list=param, ): - raise Exception('fill template failed') + raise Exception("fill template failed") # logger.debug('(%s) filled template for %s in %s' # % (pid, target_path, mrsl_path)) - (success, msg, jobid) = new_job( - mrsl_path, rule['run_as'], configuration, False, - returnjobid=True, workflow_job=workflow_job) + success, msg, jobid = new_job( + mrsl_path, + rule["run_as"], + configuration, + False, + returnjobid=True, + workflow_job=workflow_job, + ) if success: - self.__add_trigger_job_ent(configuration, - event, rule, jobid) + self.__add_trigger_job_ent(configuration, event, rule, jobid) # update vgrid workflow jobs list - vgrid = rule['vgrid_name'] + vgrid = rule["vgrid_name"] if vgrid != default_vgrid: - job_queue_entry = { - JOB_ID: jobid, - JOB_CLIENT: rule['run_as'] - } - vgrid_add_workflow_jobs(configuration, vgrid, - [job_queue_entry]) - - logger.info('(%s) submitted job for %s: %s' - % (pid, target_path, msg)) - self.__workflow_info(configuration, - rule['vgrid_name'], - 'submitted job for %s: %s' % - (rel_src, msg)) + job_queue_entry = {JOB_ID: jobid, JOB_CLIENT: rule["run_as"]} + vgrid_add_workflow_jobs(configuration, vgrid, [job_queue_entry]) + + logger.info( + "(%s) submitted job for %s: %s" % (pid, target_path, msg) + ) + self.__workflow_info( + configuration, + rule["vgrid_name"], + "submitted job for %s: %s" % (rel_src, msg), + ) else: raise Exception(msg) @@ -989,12 +1131,15 @@ def __detect_symlink(self, event): src_path = event.src_path # Check if really is not a dir, could be a symlink. - if not event.is_directory \ - and os.path.islink(src_path) \ - and os.path.isdir(src_path): + if ( + not event.is_directory + and os.path.islink(src_path) + and os.path.isdir(src_path) + ): logger.debug( "(%s) path %s is symlink to a directory. 
Updating " - "event accordingly" % (pid, src_path)) + "event accordingly" % (pid, src_path) + ) event.is_directory = True def __update_file_monitor(self, event): @@ -1006,7 +1151,7 @@ def __update_file_monitor(self, event): # If dir_modified is due to a file event we ignore it - if is_directory and state == 'created': + if is_directory and state == "created": src_path = self.__get_masked_event_path(event) rel_path = strip_base_dirs(src_path) @@ -1027,9 +1172,8 @@ def __update_file_monitor(self, event): vgrid_dir_cache[rel_path] = {} rel_path_ctime = os.path.getctime(src_path) rel_path_mtime = os.path.getmtime(src_path) - add_vgrid_file_monitor_watch(configuration, - rel_path) - vgrid_dir_cache[rel_path]['mtime'] = rel_path_mtime + add_vgrid_file_monitor_watch(configuration, rel_path) + vgrid_dir_cache[rel_path]["mtime"] = rel_path_mtime # Check if sub paths or files were changed # For create this occurs by eg. mkdir -p 'path/subpath/subpath2' @@ -1039,23 +1183,26 @@ def __update_file_monitor(self, event): if ent.is_dir(follow_symlinks=True): vgrid_sub_path = strip_base_dirs(ent.path) - if not vgrid_sub_path in \ - vgrid_dir_cache or \ - vgrid_dir_cache[vgrid_sub_path]['mtime'] \ - < rel_path_ctime: + if ( + not vgrid_sub_path in vgrid_dir_cache + or vgrid_dir_cache[vgrid_sub_path]["mtime"] + < rel_path_ctime + ): # logger.debug('(%s) %s -> Dispatch DirCreatedEvent for: %s' # % (pid, src_path, ent.path)) - shared_state['file_handler'].dispatch( - DirCreatedEvent(ent.path)) + shared_state["file_handler"].dispatch( + DirCreatedEvent(ent.path) + ) elif ent.is_file(follow_symlinks=True): # logger.debug('(%s) %s -> Dispatch FileCreatedEvent for: %s' # % (pid, src_path, ent.path)) - shared_state['file_handler'].dispatch( - FileCreatedEvent(ent.path)) + shared_state["file_handler"].dispatch( + FileCreatedEvent(ent.path) + ) except OSError as exc: # If we get an OSError, src_path was most likely deleted @@ -1073,8 +1220,11 @@ def _get_event_id(self, event): """Build a simplified string form of event properties for use in the trigger miss cache. """ - return "path=%s;state=%s;isdir=%s" % (event.src_path, event.event_type, - event.is_directory) + return "path=%s;state=%s;isdir=%s" % ( + event.src_path, + event.event_type, + event.is_directory, + ) def _update_recent_miss(self, event, hit): """Update the internal cache of recent events with no matching trigger @@ -1102,7 +1252,7 @@ def _update_recent_miss(self, event, hit): if len(miss_cache) < CACHE_EXPIRE_SIZE: return - logger.info('(%s) expire all old entries in miss cache' % pid) + logger.info("(%s) expire all old entries in miss cache" % pid) now = time.time() # NOTE: we need to iterate over a copy of keys for in-place edits @@ -1110,8 +1260,10 @@ def _update_recent_miss(self, event, hit): time_stamp = miss_cache[event_id] if time_stamp + MISS_CACHE_TTL < now: del miss_cache[event_id] - logger.info('(%s) miss cache entries left after expire: %d' % - (pid, len(miss_cache))) + logger.info( + "(%s) miss cache entries left after expire: %d" + % (pid, len(miss_cache)) + ) def _recent_miss(self, event): """Check if we recently dismissed this kind of event. We store a small @@ -1138,17 +1290,17 @@ def __get_masked_event_path(self, event): # probably be removed in favour of only monitoring writable directory. 
src_path = event.src_path - if ':' in src_path: + if ":" in src_path: for sub_vgrid in self.sub_vgrids: if sub_vgrid in src_path: - mask = sub_vgrid.replace(':', os.path.sep) + mask = sub_vgrid.replace(":", os.path.sep) masked = src_path.replace(sub_vgrid, mask) if os.path.exists(masked): src_path = masked masked = src_path.replace( - configuration.vgrid_files_writable, - configuration.vgrid_files_home) + configuration.vgrid_files_writable, configuration.vgrid_files_home + ) if os.path.exists(masked): src_path = masked @@ -1174,21 +1326,23 @@ def run_handler(self, event): # list(all_rules), src_path)) if self._recent_miss(event): - logger.debug('(%s) skip cached miss %s event for src_path: %s' % - (pid, state, src_path)) + logger.debug( + "(%s) skip cached miss %s event for src_path: %s" + % (pid, state, src_path) + ) return rule_hit = False # Each target_path pattern has one or more rules associated - for (target_path, rule_list) in all_rules.items(): + for target_path, rule_list in all_rules.items(): # Do not use ordinary fnmatch as it lets '*' match anything # including '/' which leads to greedy matching in subdirs recursive_regexp = fnmatch.translate(target_path) - direct_regexp = recursive_regexp.replace('.*', '[^/]*') + direct_regexp = recursive_regexp.replace(".*", "[^/]*") recursive_hit = re.match(recursive_regexp, src_path) direct_hit = re.match(direct_regexp, src_path) @@ -1202,28 +1356,27 @@ def run_handler(self, event): # Rules may listen for only file or dir events and with # recursive directory search - if is_directory and not rule.get('match_dirs', - False): + if is_directory and not rule.get("match_dirs", False): # logger.debug('(%s) skip event %s handling for dir: %s' # % (pid, rule['rule_id'], src_path)) continue - if not is_directory and not rule.get('match_files', - True): + if not is_directory and not rule.get("match_files", True): # logger.debug('(%s) skip %s event handling for file: %s' # % (pid, rule['rule_id'], src_path)) continue - if not direct_hit and not rule.get('match_recursive', - False): + if not direct_hit and not rule.get( + "match_recursive", False + ): # logger.debug('(%s) skip %s recurse event handling for: %s' # % (pid, rule['rule_id'], src_path)) continue - if not state in rule['changes']: + if not state in rule["changes"]: # logger.debug('(%s) skip %s %s event handling for: %s' # % (pid, rule['rule_id'], state, @@ -1241,15 +1394,19 @@ def run_handler(self, event): # (pid, rule['run_as'], rule['vgrid_name'], # rule['rule_id'])) - if not check_vgrid_access(configuration, rule['run_as'], - rule['vgrid_name']): - logger.warning('(%s) no such user in vgrid: %s' - % (pid, rule['run_as'])) + if not check_vgrid_access( + configuration, rule["run_as"], rule["vgrid_name"] + ): + logger.warning( + "(%s) no such user in vgrid: %s" + % (pid, rule["run_as"]) + ) continue - logger.info('(%s) trigger %s for src_path: %s -> %s' - % (pid, rule['action'], src_path, - rule)) + logger.info( + "(%s) trigger %s for src_path: %s -> %s" + % (pid, rule["action"], src_path, rule) + ) rule_hit = True @@ -1259,9 +1416,10 @@ def run_handler(self, event): waiting_for_worker_resources = True while waiting_for_worker_resources: try: - worker = \ - multiprocessing.Process(target=self.__handle_trigger, - args=(event, target_path, rule)) + worker = multiprocessing.Process( + target=self.__handle_trigger, + args=(event, target_path, rule), + ) worker.daemon = True worker.start() waiting_for_worker_resources = False @@ -1326,8 +1484,10 @@ def on_moved(self, event): event for a move 
between different filesystems or symlinked dirs. """ - for (change, path) in [('created', event.dest_path), - ('deleted', event.src_path)]: + for change, path in [ + ("created", event.dest_path), + ("deleted", event.src_path), + ]: fake = make_fake_event(path, change, event.is_directory) self.handle_event(fake) @@ -1339,24 +1499,25 @@ def add_vgrid_file_monitor_watch(configuration, path): vgrid_files_path = os.path.join(configuration.vgrid_files_home, path) vgrid_files_writable = os.path.join( - configuration.vgrid_files_writable, path) + configuration.vgrid_files_writable, path + ) - if path not in shared_state['file_inotify']._wd_for_path: - shared_state['file_inotify'].add_watch(force_utf8(vgrid_files_path)) + if path not in shared_state["file_inotify"]._wd_for_path: + shared_state["file_inotify"].add_watch(force_utf8(vgrid_files_path)) # logger.debug('(%s) Adding watch for: %s with path: %s' % (pid, # vgrid_files_path, path)) if os.path.sep not in path: - shared_state['file_inotify'].add_watch( - force_utf8(vgrid_files_writable)) + shared_state["file_inotify"].add_watch( + force_utf8(vgrid_files_writable) + ) # logger.debug('(%s) Adding watch for: %s' % (pid, # vgrid_files_writable)) else: - logger.warning('(%s) file_monitor already exists for: %s' - % (pid, path)) + logger.warning("(%s) file_monitor already exists for: %s" % (pid, path)) return True @@ -1380,12 +1541,12 @@ def add_vgrid_file_monitor(configuration, vgrid_name, path): # NOTE: make sure cache entry always gets initialized before use vgrid_dir_cache[path] = vgrid_dir_cache.get(path, {}) - vgrid_dir_cache[path]['mtime'] = vgrid_dir_cache[path].get('mtime', 0) + vgrid_dir_cache[path]["mtime"] = vgrid_dir_cache[path].get("mtime", 0) try: add_vgrid_file_monitor_watch(configuration, path) - if vgrid_files_path_mtime != vgrid_dir_cache[path]['mtime']: + if vgrid_files_path_mtime != vgrid_dir_cache[path]["mtime"]: # Traverse dirs for subdirs created since last run @@ -1395,17 +1556,19 @@ def add_vgrid_file_monitor(configuration, vgrid_name, path): # Force utf8 everywhere to avoid encoding issues vgrid_sub_path = force_utf8(vgrid_sub_path) if not vgrid_sub_path in vgrid_dir_cache: - retval &= add_vgrid_file_monitor(configuration, - vgrid_name, - vgrid_sub_path) + retval &= add_vgrid_file_monitor( + configuration, vgrid_name, vgrid_sub_path + ) - vgrid_dir_cache[path]['mtime'] = vgrid_files_path_mtime + vgrid_dir_cache[path]["mtime"] = vgrid_files_path_mtime except OSError as exc: # If we get an OSError, src_path was most likely deleted # after os.path.exists check or somehow not accessible - logger.warning('(%s) add_vgrid_file_monitor failed on %s: %s' % - (pid, path, exc)) + logger.warning( + "(%s) add_vgrid_file_monitor failed on %s: %s" + % (pid, path, exc) + ) del vgrid_dir_cache[path] return False @@ -1426,8 +1589,7 @@ def add_vgrid_file_monitors(configuration, vgrid_name): for path in vgrid_dir_cache_keys: # Make sure we only have utf8 everywhere to avoid encoding issues path = force_utf8(path) - vgrid_files_path = os.path.join(configuration.vgrid_files_home, - path) + vgrid_files_path = os.path.join(configuration.vgrid_files_home, path) if os.path.exists(vgrid_files_path): add_vgrid_file_monitor(configuration, vgrid_name, path) else: @@ -1447,8 +1609,7 @@ def generate_vgrid_dir_cache(configuration, vgrid_base_path): pid = multiprocessing.current_process().pid - vgrid_path = os.path.join(configuration.vgrid_files_home, - vgrid_base_path) + vgrid_path = os.path.join(configuration.vgrid_files_home, vgrid_base_path) if 
vgrid_base_path not in dir_cache: dir_cache[vgrid_base_path] = {} @@ -1458,8 +1619,7 @@ def generate_vgrid_dir_cache(configuration, vgrid_base_path): # Add VGrid root to directory cache vgrid_dir_cache[vgrid_base_path] = {} - vgrid_dir_cache[vgrid_base_path]['mtime'] = \ - os.path.getmtime(vgrid_path) + vgrid_dir_cache[vgrid_base_path]["mtime"] = os.path.getmtime(vgrid_path) # logger.debug('(%s) Updating dir_cache %s: %s' % (pid, # vgrid_base_path, @@ -1467,14 +1627,15 @@ def generate_vgrid_dir_cache(configuration, vgrid_base_path): # Add VGrid subdirs to directory cache - for (root, dir_names, _) in walk(vgrid_path, followlinks=True): + for root, dir_names, _ in walk(vgrid_path, followlinks=True): for dir_name in dir_names: dir_path = os.path.join(root, dir_name) dir_cache_path = strip_base_dirs(dir_path) if dir_cache_path not in vgrid_dir_cache: vgrid_dir_cache[dir_cache_path] = {} - vgrid_dir_cache[dir_cache_path]['mtime'] = \ - os.path.getmtime(dir_path) + vgrid_dir_cache[dir_cache_path]["mtime"] = os.path.getmtime( + dir_path + ) # logger.debug('(%s) Updating dir_cache %s: %s' % (pid, # dir_cache_path, @@ -1491,10 +1652,10 @@ def load_dir_cache(configuration, vgrid_name): pid = multiprocessing.current_process().pid vgrid_home_path = os.path.join(configuration.vgrid_home, vgrid_name) - vgrid_dir_cache_filename = '.%s.dir_cache' \ - % configuration.vgrid_triggers - vgrid_dir_cache_filepath = os.path.join(vgrid_home_path, - vgrid_dir_cache_filename) + vgrid_dir_cache_filename = ".%s.dir_cache" % configuration.vgrid_triggers + vgrid_dir_cache_filepath = os.path.join( + vgrid_home_path, vgrid_dir_cache_filename + ) # logger.debug('(%s) loading dir cache for: %s from: %s' % (pid, # vgrid_name, vgrid_dir_cache_filename)) @@ -1505,8 +1666,9 @@ def load_dir_cache(configuration, vgrid_name): # cache_t1 = time.time() - loaded_dir_cache = unpickle(vgrid_dir_cache_filepath, logger, - allow_missing=False) + loaded_dir_cache = unpickle( + vgrid_dir_cache_filepath, logger, allow_missing=False + ) # cache_t2 = time.time() # logger.debug('(%s) Loaded vgrid_dir_cache for: %s in %s secs' @@ -1514,8 +1676,10 @@ def load_dir_cache(configuration, vgrid_name): if loaded_dir_cache is False: generate_cache = True - logger.error('(%s) Failed to load vgrid_dir_cache for: %s from file: %s' - % (pid, vgrid_name, vgrid_dir_cache_filepath)) + logger.error( + "(%s) Failed to load vgrid_dir_cache for: %s from file: %s" + % (pid, vgrid_name, vgrid_dir_cache_filepath) + ) else: generate_cache = False # TODO: once all caches are migrated we can remove this loop again @@ -1533,8 +1697,10 @@ def load_dir_cache(configuration, vgrid_name): generate_cache = True if generate_cache: - logger.info('(%s) Force generation of vgrid_dir_cache for: %s' % - (pid, vgrid_name)) + logger.info( + "(%s) Force generation of vgrid_dir_cache for: %s" + % (pid, vgrid_name) + ) # cache_t1 = time.time() @@ -1557,23 +1723,25 @@ def save_dir_cache(vgrid_name): result = True - dir_cache_filename = '.%s.dir_cache' % configuration.vgrid_triggers + dir_cache_filename = ".%s.dir_cache" % configuration.vgrid_triggers vgrid_dir_cache = dir_cache.get(vgrid_name, None) if vgrid_dir_cache is not None: - vgrid_home_path = os.path.join(configuration.vgrid_home, - vgrid_name) - dir_cache_filepath = os.path.join(vgrid_home_path, - dir_cache_filename) - vgrid_dir_cache_keys = [key for key in vgrid_dir_cache if key == vgrid_name - or key.startswith('%s%s' % (vgrid_name, - os.sep))] + vgrid_home_path = os.path.join(configuration.vgrid_home, vgrid_name) + 
dir_cache_filepath = os.path.join(vgrid_home_path, dir_cache_filename) + vgrid_dir_cache_keys = [ + key + for key in vgrid_dir_cache + if key == vgrid_name + or key.startswith("%s%s" % (vgrid_name, os.sep)) + ] if len(vgrid_dir_cache_keys) == 0: - logger.info('(%s) no dirs in cache for: %s' % (pid, - vgrid_name)) + logger.info("(%s) no dirs in cache for: %s" % (pid, vgrid_name)) else: - logger.info('(%s) saving cache for: %s to file: %s' % - (pid, vgrid_name, dir_cache_filepath)) + logger.info( + "(%s) saving cache for: %s to file: %s" + % (pid, vgrid_name, dir_cache_filepath) + ) pickle(vgrid_dir_cache, dir_cache_filepath, logger) return result @@ -1602,19 +1770,24 @@ def monitor(configuration, vgrid_name): # TODO: We loose access to logger when called through multiprocessing - print('Starting monitor process with PID: %s for vgrid: %s' % (pid, - vgrid_name)) - logger.info('Starting monitor process with PID: %s for vgrid: %s' - % (pid, vgrid_name)) + print( + "Starting monitor process with PID: %s for vgrid: %s" + % (pid, vgrid_name) + ) + logger.info( + "Starting monitor process with PID: %s for vgrid: %s" + % (pid, vgrid_name) + ) # Set base directories and appropriate lengths - shared_state['base_dir'] = os.path.join(configuration.vgrid_files_home) - shared_state['base_dir_len'] = len(shared_state['base_dir']) + shared_state["base_dir"] = os.path.join(configuration.vgrid_files_home) + shared_state["base_dir_len"] = len(shared_state["base_dir"]) - shared_state['writable_dir'] = os.path.join( - configuration.vgrid_files_writable) - shared_state['writable_dir_len'] = len(shared_state['writable_dir']) + shared_state["writable_dir"] = os.path.join( + configuration.vgrid_files_writable + ) + shared_state["writable_dir_len"] = len(shared_state["writable_dir"]) # Allow e.g. 
logrotate to force log re-open after rotates register_hangup_handler(configuration) @@ -1622,108 +1795,130 @@ def monitor(configuration, vgrid_name): # Monitor rule configurations writeable_sub_vgrids = [] - if vgrid_name == '.': + if vgrid_name == ".": vgrid_home = configuration.vgrid_home - file_monitor_home = shared_state['base_dir'] - writable_dir = shared_state['writable_dir'] + file_monitor_home = shared_state["base_dir"] + writable_dir = shared_state["writable_dir"] recursive_rule_monitor = False else: vgrid_home = os.path.join(configuration.vgrid_home, vgrid_name) - file_monitor_home = os.path.join(shared_state['base_dir'], vgrid_name) - writable_dir = os.path.join(shared_state['writable_dir'], vgrid_name) + file_monitor_home = os.path.join(shared_state["base_dir"], vgrid_name) + writable_dir = os.path.join(shared_state["writable_dir"], vgrid_name) recursive_rule_monitor = True # Check for sub vgrids by checking for sub directories in vgrid_home - for (root, dirs, _) in walk(vgrid_home): + for root, dirs, _ in walk(vgrid_home): for dir_name in dirs: # Need to join then replace here to catch sub-sub vgrids - sub_vgrid = os.path.join(root, dir_name).replace( - configuration.vgrid_home, '').replace(os.path.sep, ':') + sub_vgrid = ( + os.path.join(root, dir_name) + .replace(configuration.vgrid_home, "") + .replace(os.path.sep, ":") + ) sub_vgrid_path = os.path.join( - configuration.vgrid_files_writable, sub_vgrid) - if not sub_vgrid.startswith('.') \ - and os.path.exists(sub_vgrid_path): + configuration.vgrid_files_writable, sub_vgrid + ) + if not sub_vgrid.startswith(".") and os.path.exists( + sub_vgrid_path + ): writeable_sub_vgrids.append(sub_vgrid) if writeable_sub_vgrids: - msg = 'Within vgrid %s, Found sub vgrids: %s' \ - % (vgrid_name, writeable_sub_vgrids) + msg = "Within vgrid %s, Found sub vgrids: %s" % ( + vgrid_name, + writeable_sub_vgrids, + ) print(msg) logger.info(msg) rule_monitor = Observer() - rule_patterns = [os.path.join(vgrid_home, '*')] - shared_state['rule_handler'] = MiGRuleEventHandler( - patterns=rule_patterns, ignore_directories=False, case_sensitive=True) - - rule_monitor.schedule(shared_state['rule_handler'], vgrid_home, - recursive=recursive_rule_monitor) + rule_patterns = [os.path.join(vgrid_home, "*")] + shared_state["rule_handler"] = MiGRuleEventHandler( + patterns=rule_patterns, ignore_directories=False, case_sensitive=True + ) + + rule_monitor.schedule( + shared_state["rule_handler"], + vgrid_home, + recursive=recursive_rule_monitor, + ) rule_monitor.start() if len(rule_monitor._emitters) != 1: - logger.error('(%s) Number of rule_monitor._emitters != 1' % pid) + logger.error("(%s) Number of rule_monitor._emitters != 1" % pid) return 1 rule_monitor_emitter = min(rule_monitor._emitters) - if not hasattr(rule_monitor_emitter, '_inotify'): - logger.error('(%s) rule_monitor_emitter require inotify' % pid) + if not hasattr(rule_monitor_emitter, "_inotify"): + logger.error("(%s) rule_monitor_emitter require inotify" % pid) return 1 - shared_state['rule_inotify'] = rule_monitor_emitter._inotify._inotify + shared_state["rule_inotify"] = rule_monitor_emitter._inotify._inotify - logger.info('(%s) initializing file listener - may take some time' - % pid) + logger.info("(%s) initializing file listener - may take some time" % pid) # monitor actual files to handle events for vgrid_files_home file_monitor = Observer() file_patterns = [ - os.path.join(file_monitor_home, '*'), - os.path.join(writable_dir, '*') + os.path.join(file_monitor_home, "*"), + 
os.path.join(writable_dir, "*"), ] for sub_vgrid in writeable_sub_vgrids: - sub_vgrid_home = os.path.join(configuration.vgrid_files_home, - sub_vgrid.replace(':', os.path.sep), ) + sub_vgrid_home = os.path.join( + configuration.vgrid_files_home, + sub_vgrid.replace(":", os.path.sep), + ) if os.path.exists(sub_vgrid_home): - file_patterns.append(os.path.join(sub_vgrid_home, '*')) + file_patterns.append(os.path.join(sub_vgrid_home, "*")) - sub_vgrid_writeable = os.path.join(configuration.vgrid_files_writable, - sub_vgrid) + sub_vgrid_writeable = os.path.join( + configuration.vgrid_files_writable, sub_vgrid + ) if os.path.exists(sub_vgrid_writeable): - file_patterns.append(os.path.join(sub_vgrid_writeable, '*')) + file_patterns.append(os.path.join(sub_vgrid_writeable, "*")) - logger.info('(%s) initializing listener with patterns: %s' - % (pid, file_patterns)) + logger.info( + "(%s) initializing listener with patterns: %s" % (pid, file_patterns) + ) - shared_state['file_handler'] = MiGFileEventHandler( - patterns=file_patterns, ignore_directories=False, case_sensitive=True, - sub_vgrids=writeable_sub_vgrids) + shared_state["file_handler"] = MiGFileEventHandler( + patterns=file_patterns, + ignore_directories=False, + case_sensitive=True, + sub_vgrids=writeable_sub_vgrids, + ) - vgrid_homes = [os.path.join(configuration.vgrid_files_writable, sub_vgrid) - for sub_vgrid in writeable_sub_vgrids] + vgrid_homes = [ + os.path.join(configuration.vgrid_files_writable, sub_vgrid) + for sub_vgrid in writeable_sub_vgrids + ] vgrid_homes.append(file_monitor_home) for monitor_home in vgrid_homes: - logger.info('(%s) starting observer for: %s' % (pid, monitor_home)) + logger.info("(%s) starting observer for: %s" % (pid, monitor_home)) file_monitor = Observer() - file_monitor.schedule(shared_state['file_handler'], monitor_home, - recursive=False) + file_monitor.schedule( + shared_state["file_handler"], monitor_home, recursive=False + ) file_monitor.start() if len(file_monitor._emitters) != 1: - logger.error('(%s) Number of file_monitor._emitters != 1' % pid) + logger.error("(%s) Number of file_monitor._emitters != 1" % pid) return 1 file_monitor_emitter = min(file_monitor._emitters) - if not hasattr(file_monitor_emitter, '_inotify'): - logger.error('(%s) file_monitor require inotify' % pid) + if not hasattr(file_monitor_emitter, "_inotify"): + logger.error("(%s) file_monitor require inotify" % pid) return 1 - shared_state['file_inotify'] = file_monitor_emitter._inotify._inotify + shared_state["file_inotify"] = file_monitor_emitter._inotify._inotify - logger.info('(%s) trigger rule refresh for: %s' % (pid, vgrid_name)) + logger.info("(%s) trigger rule refresh for: %s" % (pid, vgrid_name)) # Fake touch event on all rule files to load initial rules - logger.info('(%s) trigger load on all rule files (greedy) for: %s matching %s' - % (pid, vgrid_name, rule_patterns[0])) + logger.info( + "(%s) trigger load on all rule files (greedy) for: %s matching %s" + % (pid, vgrid_name, rule_patterns[0]) + ) # We manually walk and test to get the greedy "*" directory match behaviour # of the PatternMatchingEventHandler @@ -1731,10 +1926,9 @@ def monitor(configuration, vgrid_name): all_trigger_rules = [] if recursive_rule_monitor: - for (root, _, files) in walk(vgrid_home): + for root, _, files in walk(vgrid_home): if configuration.vgrid_triggers in files: - rule_path = os.path.join(root, - configuration.vgrid_triggers) + rule_path = os.path.join(root, configuration.vgrid_triggers) all_trigger_rules.append(rule_path) else: for 
ent in scandir(vgrid_home): @@ -1747,13 +1941,13 @@ def monitor(configuration, vgrid_name): # logger.debug('(%s) trigger load on rules in %s' % (pid, # rule_path)) - shared_state['rule_handler'].dispatch(FileModifiedEvent(rule_path)) + shared_state["rule_handler"].dispatch(FileModifiedEvent(rule_path)) # logger.debug('(%s) loaded initial rules:\n%s' % (pid, all_rules)) # Add watches for directories - if vgrid_name == '.': + if vgrid_name == ".": # logger.debug('(%s) Skipping dir_cache load for root dir: %s' # % (pid, vgrid_name)) @@ -1771,8 +1965,10 @@ def monitor(configuration, vgrid_name): # - load_dir_cache_t1)) if not load_status: - logger.error('(%s) Failed to load / generate dir cache for: %s' - % (pid, vgrid_name)) + logger.error( + "(%s) Failed to load / generate dir cache for: %s" + % (pid, vgrid_name) + ) stop_running() activated = False @@ -1787,21 +1983,26 @@ def monitor(configuration, vgrid_name): if not activated: if active_targets(configuration, vgrid_name, file_monitor_home): # Start paths in vgrid_dir_cache to monitor - print('(%s) init trigger handling for: %s' % (pid, vgrid_name)) + print("(%s) init trigger handling for: %s" % (pid, vgrid_name)) add_monitor_t1 = time.time() add_vgrid_file_monitors(configuration, vgrid_name) add_monitor_t2 = time.time() - print('(%s) ready to handle triggers for: %s in %s secs' - % (pid, vgrid_name, add_monitor_t2 - add_monitor_t1)) - logger.info('(%s) ready to handle triggers for: %s in %s secs' - % (pid, vgrid_name, add_monitor_t2 - - add_monitor_t1)) + print( + "(%s) ready to handle triggers for: %s in %s secs" + % (pid, vgrid_name, add_monitor_t2 - add_monitor_t1) + ) + logger.info( + "(%s) ready to handle triggers for: %s in %s secs" + % (pid, vgrid_name, add_monitor_t2 - add_monitor_t1) + ) activated = True else: # Variable per-process delay to avoid thrashing delay = 60 + pid % 30 - logger.debug('(%s) no matching triggers for %s - sleep %ds' % - (pid, vgrid_name, delay)) + logger.debug( + "(%s) no matching triggers for %s - sleep %ds" + % (pid, vgrid_name, delay) + ) time.sleep(delay) # Once past the activation we just sleep in a responsive loop @@ -1812,35 +2013,33 @@ def monitor(configuration, vgrid_name): time.sleep(1) except KeyboardInterrupt: - print('(%s) caught interrupt' % pid) - logger.info('(%s) caught interrupt' % pid) + print("(%s) caught interrupt" % pid) + logger.info("(%s) caught interrupt" % pid) stop_running() # Only save cache if rules were actually activated so dirs were monitored if activated: - print('(%s) Saving cache for vgrid: %s' % (pid, vgrid_name)) - logger.info('(%s) Saving cache for vgrid: %s' % (pid, vgrid_name)) + print("(%s) Saving cache for vgrid: %s" % (pid, vgrid_name)) + logger.info("(%s) Saving cache for vgrid: %s" % (pid, vgrid_name)) save_dir_cache(vgrid_name) - print('(%s) Exiting monitor for vgrid: %s' % (pid, vgrid_name)) - logger.info('(%s) Exiting for vgrid: %s' % (pid, vgrid_name)) + print("(%s) Exiting monitor for vgrid: %s" % (pid, vgrid_name)) + logger.info("(%s) Exiting for vgrid: %s" % (pid, vgrid_name)) return 0 -if __name__ == '__main__': +if __name__ == "__main__": # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', - 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] # Use separate logger - logger = daemon_logger('events', configuration.user_events_log, - log_level) + 
logger = daemon_logger("events", configuration.user_events_log, log_level) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates @@ -1855,38 +2054,39 @@ def monitor(configuration, vgrid_name): print(err_msg) sys.exit(1) - print('''This is the MiG event handler daemon which monitors VGrid files + print("""This is the MiG event handler daemon which monitors VGrid files and triggers any configured events when target files are created, modifed or deleted. VGrid owners can configure rules to trigger such events based on file changes. Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -''') +""") main_pid = os.getpid() - print('Starting Event handler daemon - Ctrl-C to quit') - logger.info('(%s) Starting Event handler daemon' % main_pid) + print("Starting Event handler daemon - Ctrl-C to quit") + logger.info("(%s) Starting Event handler daemon" % main_pid) vgrid_monitors = {} # Start monitor for new/removed vgrids - vgrid_name = '.' - vgrid_monitors[vgrid_name] = \ - multiprocessing.Process(target=monitor, args=(configuration, - vgrid_name)) + vgrid_name = "." + vgrid_monitors[vgrid_name] = multiprocessing.Process( + target=monitor, args=(configuration, vgrid_name) + ) # Each top vgrid gets is own process for ent in scandir(configuration.vgrid_home): - vgrid_files_path = os.path.join(configuration.vgrid_files_home, - ent.name) + vgrid_files_path = os.path.join( + configuration.vgrid_files_home, ent.name + ) if os.path.isdir(ent.path) and os.path.isdir(vgrid_files_path): vgrid_name = ent.name - vgrid_monitors[vgrid_name] = \ - multiprocessing.Process(target=monitor, - args=(configuration, vgrid_name)) + vgrid_monitors[vgrid_name] = multiprocessing.Process( + target=monitor, args=(configuration, vgrid_name) + ) # else: # logger.debug('Skipping _NON_ vgrid: %s' % ent.path) @@ -1894,7 +2094,7 @@ def monitor(configuration, vgrid_name): for monitor in vgrid_monitors.values(): monitor.start() - logger.debug('(%s) Starting main loop' % main_pid) + logger.debug("(%s) Starting main loop" % main_pid) print("%s: Start main loop" % os.getpid()) while not check_stop(): try: @@ -1907,31 +2107,33 @@ def monitor(configuration, vgrid_name): # NOTE: we can't be sure if SIGINT was sent to only main process # so we make sure to propagate to all monitor children print("Interrupt requested - close monitors and shutdown") - logger.info('(%s) Shut down monitors and wait' % os.getpid()) + logger.info("(%s) Shut down monitors and wait" % os.getpid()) for monitor in vgrid_monitors.values(): mon_pid = monitor.pid if mon_pid is None: continue - logger.debug('send exit signal to monitor %s' % mon_pid) + logger.debug("send exit signal to monitor %s" % mon_pid) os.kill(mon_pid, signal.SIGINT) - logger.info('Wait for monitors to clean up') + logger.info("Wait for monitors to clean up") for monitor in vgrid_monitors.values(): mon_pid = monitor.pid - logger.debug('wait for monitor %s: %s' % (mon_pid, - monitor.is_alive())) + logger.debug( + "wait for monitor %s: %s" % (mon_pid, monitor.is_alive()) + ) monitor.join(5) if monitor.is_alive(): - logger.warning("force kill %s: %s" % (mon_pid, - monitor.is_alive())) + logger.warning( + "force kill %s: %s" % (mon_pid, monitor.is_alive()) + ) monitor.terminate() else: - logger.debug('monitor %s: done' % mon_pid) + logger.debug("monitor %s: done" % mon_pid) - logger.info('(%s) Shut down: all monitors done' % os.getpid()) + logger.info("(%s) Shut down: all monitors done" % os.getpid()) 
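# The interrupt handling above follows a standard multiprocessing shutdown
# pattern: signal every child explicitly, wait with a bounded join, then
# terminate stragglers. A minimal sketch of just that flow with an invented
# worker; only the SIGINT + join(timeout) + terminate() sequence mirrors the
# daemon.
import multiprocessing
import os
import signal
import time


def _worker():
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass  # a real monitor saves its dir cache here before exiting


if __name__ == "__main__":
    children = [multiprocessing.Process(target=_worker) for _ in range(2)]
    for child in children:
        child.start()
    time.sleep(1)
    for child in children:
        os.kill(child.pid, signal.SIGINT)  # propagate the interrupt explicitly
    for child in children:
        child.join(5)  # bounded wait, like monitor.join(5) above
        if child.is_alive():
            child.terminate()  # force kill anything still running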
print("All monitors finished shutting down") - print('Event handler daemon shutting down') - logger.info('(%s) Event handler daemon shutting down' % main_pid) + print("Event handler daemon shutting down") + logger.info("(%s) Event handler daemon shutting down" % main_pid) sys.exit(0) diff --git a/mig/server/grid_imnotify.py b/mig/server/grid_imnotify.py index c4144d399..aa12391ad 100755 --- a/mig/server/grid_imnotify.py +++ b/mig/server/grid_imnotify.py @@ -49,16 +49,16 @@ authors. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function from future import standard_library + standard_library.install_aliases() -from builtins import range +import _thread import os import sys import time -import _thread +from builtins import range try: import irclib @@ -70,11 +70,11 @@ getting_buddy_list = False protocol_online_dict = { - 'jabber': False, - 'msn': False, - 'yahoo': False, - 'icq': False, - 'aol': False, + "jabber": False, + "msn": False, + "yahoo": False, + "icq": False, + "aol": False, } nick_and_id_dict = {} @@ -89,7 +89,7 @@ def send_msg( ): """Send IM request through connection""" - print('send msg called') + print("send msg called") global getting_buddy_list global latest_contact_online global protocol_online_dict @@ -101,15 +101,17 @@ def send_msg( got_online = False for _ in range(30): if not protocol_online_dict[im_network]: - print('waiting for protocol %s to get online (status for all protocols: %s)' - % (im_network, protocol_online_dict)) + print( + "waiting for protocol %s to get online (status for all protocols: %s)" + % (im_network, protocol_online_dict) + ) time.sleep(2) else: got_online = True break if not got_online: - raise Exception('gave up waiting to get online') + raise Exception("gave up waiting to get online") # Fetch buddy list and let any exceptions pass to caller @@ -120,62 +122,69 @@ def send_msg( if not getting_buddy_list: got_buddy_list = True break - connection.privmsg('root', 'blist all') + connection.privmsg("root", "blist all") for _ in range(30): if not getting_buddy_list: break - print('waiting while buddy list is generated') + print("waiting while buddy list is generated") time.sleep(2) if not got_buddy_list: - raise Exception('gave up waiting for buddy list') + raise Exception("gave up waiting for buddy list") replaced_im_network = im_network - if replaced_im_network == 'aol': - dest += '@login.oscar.aol.com' - replaced_im_network = 'osc' - elif replaced_im_network == 'icq': - dest += '@login.icq.com' - replaced_im_network = 'osc' - if im_network == 'yahoo': - if not dest.endswith('@yahoo'): - dest += '@yahoo' - print('looking for %s_%s in %s' % (replaced_im_network, dest, - nick_and_id_dict)) - if '%s_%s' % (replaced_im_network, dest) in nick_and_id_dict: + if replaced_im_network == "aol": + dest += "@login.oscar.aol.com" + replaced_im_network = "osc" + elif replaced_im_network == "icq": + dest += "@login.icq.com" + replaced_im_network = "osc" + if im_network == "yahoo": + if not dest.endswith("@yahoo"): + dest += "@yahoo" + print( + "looking for %s_%s in %s" + % (replaced_im_network, dest, nick_and_id_dict) + ) + if "%s_%s" % (replaced_im_network, dest) in nick_and_id_dict: # nick was found in buddy dict. Get the nickname. 
- id_dict = nick_and_id_dict['%s_%s' % (replaced_im_network, - dest)] - print('account %s_%s found in buddy list: %s' % ( - replaced_im_network, dest, id_dict)) - nickname = id_dict['nick'] + id_dict = nick_and_id_dict["%s_%s" % (replaced_im_network, dest)] + print( + "account %s_%s found in buddy list: %s" + % (replaced_im_network, dest, id_dict) + ) + nickname = id_dict["nick"] else: # nick was not found in buddy dict, add user - print('account %s_%s not found in buddy list, adding..' - % (im_network, dest)) + print( + "account %s_%s not found in buddy list, adding.." + % (im_network, dest) + ) # Get protocol ID (called account) account_number = get_account_number(im_network) - all_nicks = [i['nick'] for i in nick_and_id_dict.values()] + all_nicks = [i["nick"] for i in nick_and_id_dict.values()] # assign unique local nick: len does not always yield highest # nickname index as e.g. illegal addresses won't get # permanently inserted. Increment index until unique! id_index = len(nick_and_id_dict) - nickname = 'nick%d' % id_index + nickname = "nick%d" % id_index while nickname in all_nicks: id_index += 1 - nickname = 'nick%d' % id_index + nickname = "nick%d" % id_index - print('assigned local nick %s to new user %s with %d nicks' - % (nickname, dest, len(nick_and_id_dict))) + print( + "assigned local nick %s to new user %s with %d nicks" + % (nickname, dest, len(nick_and_id_dict)) + ) # give contact a second to get online @@ -183,14 +192,15 @@ def send_msg( # add contact - print('add %s %s %s' % (account_number, dest, nickname)) - connection.privmsg('root', 'add %s %s %s' % (account_number, - dest, nickname)) + print("add %s %s %s" % (account_number, dest, nickname)) + connection.privmsg( + "root", "add %s %s %s" % (account_number, dest, nickname) + ) time.sleep(2) # actually send the message - for m in msg.split('
<br>'):
+ for m in msg.split("<br>
"): connection.privmsg(nickname, m) # sleep a bit to keep messages in correct order @@ -199,26 +209,26 @@ def send_msg( def on_connect(connection, event): - print('on_connect') + print("on_connect") if irclib.is_channel(target): connection.join(target) else: - print('target should be a channel!') + print("target should be a channel!") def get_account_number(im_network): # TODO: automatically get account numbers by calling and parsing an "account list" call - if im_network == 'msn': + if im_network == "msn": return 0 - elif im_network == 'jabber': + elif im_network == "jabber": return 4 - elif im_network == 'yahoo': + elif im_network == "yahoo": return 1 - elif im_network == 'icq': + elif im_network == "icq": return 2 - elif im_network == 'aol': + elif im_network == "aol": return 3 @@ -226,7 +236,7 @@ def on_privmsg(connection, event): global protocol_online_dict global nick_and_id_dict - if event.source() != 'root!root@%s' % server: + if event.source() != "root!root@%s" % server: # message should never be accepted if it is not sent by "root" @@ -235,54 +245,54 @@ def on_privmsg(connection, event): recvd = event.arguments()[0] recvd_split = recvd.split() - if recvd == 'Password accepted': + if recvd == "Password accepted": pass - elif recvd.startswith('msn - Logging in: Logged in'): - protocol_online_dict['msn'] = True - elif recvd.startswith('jabber - Logging in: Logged in'): - protocol_online_dict['jabber'] = True - elif recvd.startswith('YAHOO - Logged in'): - protocol_online_dict['yahoo'] = True - elif recvd.startswith('ICQ(275655718) - Logged in'): - protocol_online_dict['icq'] = True - elif recvd.startswith('TOC(migdaemon) - Logged in'): - protocol_online_dict['aol'] = True - elif len(recvd_split) >= 4 and recvd_split[1].find('@') >= 0: + elif recvd.startswith("msn - Logging in: Logged in"): + protocol_online_dict["msn"] = True + elif recvd.startswith("jabber - Logging in: Logged in"): + protocol_online_dict["jabber"] = True + elif recvd.startswith("YAHOO - Logged in"): + protocol_online_dict["yahoo"] = True + elif recvd.startswith("ICQ(275655718) - Logged in"): + protocol_online_dict["icq"] = True + elif recvd.startswith("TOC(migdaemon) - Logged in"): + protocol_online_dict["aol"] = True + elif len(recvd_split) >= 4 and recvd_split[1].find("@") >= 0: # "blist all" reply. Create a small dict containing info about this single contact # TODO: make the if check more specific to be sure wrong messages are never accepted # recvd_split[2] is on the form: jabber(mig_daemon@jab im_network_tmp_split = recvd_split[2].split( - '(') # rstrip(")").lstrip("(").lower() # (YAHOO) -> yahoo + "(" + ) # rstrip(")").lstrip("(").lower() # (YAHOO) -> yahoo im_network = im_network_tmp_split[0] - if im_network.startswith('osc'): - im_network = 'osc' + if im_network.startswith("osc"): + im_network = "osc" im_id = recvd_split[1] # henrik_karlsen@hotmail.com id_dict = {} - id_dict['nick'] = recvd_split[0] # henrik_karlsen - id_dict['status'] = recvd_split[3] # (Online) (verify format) + id_dict["nick"] = recvd_split[0] # henrik_karlsen + id_dict["status"] = recvd_split[3] # (Online) (verify format) # unique id is im_network_im_id, eg. 
msn_henrik_karlsen@hotmail.com - nick_and_id_dict['%s_%s' % (im_network, im_id.lower())] = \ - id_dict - print('new dict entry: %s_%s' % (im_network, im_id.lower())) + nick_and_id_dict["%s_%s" % (im_network, im_id.lower())] = id_dict + print("new dict entry: %s_%s" % (im_network, im_id.lower())) elif len(recvd_split) > 2: - if recvd_split[1] == 'buddies': + if recvd_split[1] == "buddies": # end of buddy list global getting_buddy_list getting_buddy_list = False - print('buddy list end..') + print("buddy list end..") else: - print('Unknown message: %s' % recvd) + print("Unknown message: %s" % recvd) def on_pubmsg(connection, event): - print('pubmsg: %s' % event.arguments()[0]) + print("pubmsg: %s" % event.arguments()[0]) def on_join(connection, event): @@ -293,12 +303,11 @@ def on_join(connection, event): # login to bitlbee - login_msg = 'identify %s' % bitlbee_password - print('identify **REDACTED**') - connection.privmsg('root', login_msg) + login_msg = "identify %s" % bitlbee_password + print("identify **REDACTED**") + connection.privmsg("root", login_msg) else: - print('someone joined channel: %s' - % irclib.nm_to_n(event.source())) + print("someone joined channel: %s" % irclib.nm_to_n(event.source())) def on_disconnect(connection, event): @@ -309,17 +318,18 @@ def irc_process_forever(*args): irc.process_forever() -if __name__ == '__main__': +if __name__ == "__main__": # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] # Use separate logger - logger = daemon_logger("imnotify", configuration.user_imnotify_log, - log_level) + logger = daemon_logger( + "imnotify", configuration.user_imnotify_log, log_level + ) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates @@ -347,12 +357,17 @@ def irc_process_forever(*args): server = configuration.user_imnotify_address port = configuration.user_imnotify_port - target = '#%s' % configuration.user_imnotify_channel + target = "#%s" % configuration.user_imnotify_channel nickname = configuration.user_imnotify_username bitlbee_password = configuration.user_imnotify_password - if not server or not port or not target or not nickname or \ - not bitlbee_password: + if ( + not server + or not port + or not target + or not nickname + or not bitlbee_password + ): print(server, port, target, nickname) err_msg = "IM notify helper setup is incomplete in configuration!" logger.error(err_msg) @@ -373,49 +388,49 @@ def irc_process_forever(*args): try: if not os.path.exists(stdin_path): - print('creating im_notify input pipe %s' % stdin_path) + print("creating im_notify input pipe %s" % stdin_path) try: os.mkfifo(stdin_path) except Exception as err: - print('Could not create missing IM stdin pipe %s: %s' - % (stdin_path, err)) + print( + "Could not create missing IM stdin pipe %s: %s" + % (stdin_path, err) + ) except: - print('error opening IM stdin! %s' % sys.exc_info()[0]) + print("error opening IM stdin! 
%s" % sys.exc_info()[0]) sys.exit(1) keep_running = True - print('Starting Real IM daemon - Ctrl-C to quit') + print("Starting Real IM daemon - Ctrl-C to quit") - print('Reading commands from %s' % stdin_path) + print("Reading commands from %s" % stdin_path) try: - im_notify_stdin = open(stdin_path, 'r') + im_notify_stdin = open(stdin_path, "r") except KeyboardInterrupt: keep_running = False except Exception as err: - print('could not open IM stdin %s, exception: %s' % (stdin_path, - err)) + print("could not open IM stdin %s, exception: %s" % (stdin_path, err)) sys.exit(1) while keep_running: try: if not irc: - print('Initialising IRC access to %s' % server) + print("Initialising IRC access to %s" % server) irc = irclib.IRC() try: - irc_server = irc.server().connect(server, port, - nickname) + irc_server = irc.server().connect(server, port, nickname) except irclib.ServerConnectionError as exc: - print('Could not connect to irc server: %s' % exc) + print("Could not connect to irc server: %s" % exc) irc = None time.sleep(30) continue - irc_server.add_global_handler('connect', on_connect) - irc_server.add_global_handler('join', on_join) - irc_server.add_global_handler('disconnect', on_disconnect) - irc_server.add_global_handler('privmsg', on_privmsg) - irc_server.add_global_handler('pubmsg', on_pubmsg) + irc_server.add_global_handler("connect", on_connect) + irc_server.add_global_handler("join", on_join) + irc_server.add_global_handler("disconnect", on_disconnect) + irc_server.add_global_handler("privmsg", on_privmsg) + irc_server.add_global_handler("pubmsg", on_pubmsg) _thread.start_new_thread(irc_process_forever, ()) # Handle messages @@ -435,7 +450,7 @@ def irc_process_forever(*args): if not line or attempt >= max_retries: line = im_notify_stdin.readline() attempt = 0 - if line.upper().startswith('SENDMESSAGE '): + if line.upper().startswith("SENDMESSAGE "): # The received line should be on a format similar to: # SENDMESSAGE PROTOCOL TO MESSAGE ex: @@ -443,30 +458,33 @@ def irc_process_forever(*args): # split string - split_line = line.split(' ', 3) + split_line = line.split(" ", 3) if len(split_line) != 4: - print('received SENDMESSAGE not on correct format %s' - % line) + print( + "received SENDMESSAGE not on correct format %s" % line + ) continue protocol = split_line[1] recipient = split_line[2] message = split_line[3] - print('Sending message: protocol: %s to: %s message: %s' - % (protocol, recipient, message)) + print( + "Sending message: protocol: %s to: %s message: %s" + % (protocol, recipient, message) + ) send_msg(irc_server, recipient, protocol, message) - print('Message sent to %s' % recipient) - elif line.upper().startswith('SHOWBUDDIES'): - print('Buddy list:') - for (key, val) in nick_and_id_dict.items(): - print('%s:\n\t%s' % (key, val)) - print('-----') - elif line.upper().startswith('SHUTDOWN'): - print('--- SAFE SHUTDOWN INITIATED ---') + print("Message sent to %s" % recipient) + elif line.upper().startswith("SHOWBUDDIES"): + print("Buddy list:") + for key, val in nick_and_id_dict.items(): + print("%s:\n\t%s" % (key, val)) + print("-----") + elif line.upper().startswith("SHUTDOWN"): + print("--- SAFE SHUTDOWN INITIATED ---") break elif line: - print('unknown message received: %s' % line) + print("unknown message received: %s" % line) line = None # Throttle down @@ -475,9 +493,9 @@ def irc_process_forever(*args): except KeyboardInterrupt: keep_running = False except Exception as exc: - print('Caught unexpected exception: %s' % exc) + print("Caught unexpected exception: 
%s" % exc) irc = None attempt += 1 - print('Real IM daemon shutting down') + print("Real IM daemon shutting down") sys.exit(0) diff --git a/mig/server/grid_imnotify_stdout.py b/mig/server/grid_imnotify_stdout.py index 1f3f80092..ce9a83635 100755 --- a/mig/server/grid_imnotify_stdout.py +++ b/mig/server/grid_imnotify_stdout.py @@ -27,8 +27,7 @@ """Dummy IM daemon writing requests to stdout instead of sending them""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import os import sys @@ -39,18 +38,19 @@ configuration, logger = None, None -if __name__ == '__main__': +if __name__ == "__main__": # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) - print(os.environ.get('MIG_CONF', 'DEFAULT'), configuration.server_fqdn) + print(os.environ.get("MIG_CONF", "DEFAULT"), configuration.server_fqdn) log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] # Use separate logger - logger = daemon_logger("imnotify", configuration.user_imnotify_log, - log_level) + logger = daemon_logger( + "imnotify", configuration.user_imnotify_log, log_level + ) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates @@ -62,7 +62,7 @@ print(err_msg) sys.exit(1) - print('''This is a dummy MiG IM notification daemon which just prints all + print("""This is a dummy MiG IM notification daemon which just prints all requests. The real notification daemon, grid_imnotify.py, hard codes accounts and @@ -72,45 +72,47 @@ Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf - ''') + """) - print('Starting Dummy IM daemon - Ctrl-C to quit') + print("Starting Dummy IM daemon - Ctrl-C to quit") stdin_path = configuration.im_notify_stdin try: if not os.path.exists(stdin_path): - print('creating im_notify input pipe %s' % stdin_path) + print("creating im_notify input pipe %s" % stdin_path) try: os.mkfifo(stdin_path) except Exception as err: - print('Could not create missing IM stdin pipe %s: %s' - % (stdin_path, err)) + print( + "Could not create missing IM stdin pipe %s: %s" + % (stdin_path, err) + ) except: - print('error opening IM stdin! %s' % sys.exc_info()[0]) + print("error opening IM stdin! 
%s" % sys.exc_info()[0]) sys.exit(1) keep_running = True - print('Reading commands from %s' % stdin_path) + print("Reading commands from %s" % stdin_path) try: - im_notify_stdin = open(stdin_path, 'r') + im_notify_stdin = open(stdin_path, "r") except KeyboardInterrupt: keep_running = False except Exception as exc: - print('could not open IM stdin %s: %s' % (stdin_path, exc)) + print("could not open IM stdin %s: %s" % (stdin_path, exc)) sys.exit(1) while keep_running: try: line = im_notify_stdin.readline() - if line.upper().startswith('SENDMESSAGE '): + if line.upper().startswith("SENDMESSAGE "): print(line) - elif line.upper().startswith('SHUTDOWN'): - print('--- SAFE SHUTDOWN INITIATED ---') + elif line.upper().startswith("SHUTDOWN"): + print("--- SAFE SHUTDOWN INITIATED ---") break elif line: - print('unknown message received: %s' % line) + print("unknown message received: %s" % line) # Throttle down @@ -118,7 +120,7 @@ except KeyboardInterrupt: keep_running = False except Exception as exc: - print('Caught unexpected exception: %s' % exc) + print("Caught unexpected exception: %s" % exc) - print('Dummy IM daemon shutting down') + print("Dummy IM daemon shutting down") sys.exit(0) diff --git a/mig/server/grid_monitor.py b/mig/server/grid_monitor.py index 9b03bef27..d914576cd 100755 --- a/mig/server/grid_monitor.py +++ b/mig/server/grid_monitor.py @@ -27,8 +27,7 @@ """Creating the MiG monitor page""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import datetime import os @@ -39,8 +38,12 @@ from mig.shared.defaults import default_vgrid from mig.shared.fileio import unpickle from mig.shared.gridstat import GridStat -from mig.shared.htmlgen import get_xgi_html_header, get_xgi_html_footer, \ - themed_styles, themed_scripts +from mig.shared.htmlgen import ( + get_xgi_html_footer, + get_xgi_html_header, + themed_scripts, + themed_styles, +) from mig.shared.logger import daemon_logger, register_hangup_handler from mig.shared.output import format_timedelta from mig.shared.resource import anon_resource_id @@ -52,29 +55,32 @@ def create_monitor(vgrid_name): """Write monitor HTML file for vgrid_name""" - html_file = os.path.join(configuration.vgrid_home, vgrid_name, - '%s.html' % configuration.vgrid_monitor) + html_file = os.path.join( + configuration.vgrid_home, + vgrid_name, + "%s.html" % configuration.vgrid_monitor, + ) - print('collecting statistics for VGrid %s' % vgrid_name) + print("collecting statistics for VGrid %s" % vgrid_name) sleep_secs = configuration.sleep_secs slackperiod = configuration.slackperiod now = time.asctime(time.localtime()) html_vars = { - 'sleep_secs': sleep_secs, - 'vgrid_name': vgrid_name, - 'logo_url': '/images/logo.jpg', - 'now': now, - 'short_title': configuration.short_title, + "sleep_secs": sleep_secs, + "vgrid_name": vgrid_name, + "logo_url": "/images/logo.jpg", + "now": now, + "short_title": configuration.short_title, } - monitor_meta = ''' -''' % html_vars - add_import = ''' + monitor_meta = """ +""" % html_vars + add_import = """ - ''' - add_init = '' - add_ready = ''' + """ + add_init = "" + add_ready = """ // table initially sorted by col. 
1 (name) var sortOrder = [[1,0]]; @@ -96,8 +102,8 @@ def create_monitor(vgrid_name): /* tablesorter chokes on empty tables - just continue */ } }); - ''' - monitor_js = ''' + """ + monitor_js = """ %s -''' % (add_import, add_init, add_ready) +""" % (add_import, add_init, add_ready) # User default site style style_helpers = themed_styles(configuration) script_helpers = themed_scripts(configuration) - script_helpers['advanced'] += add_import - script_helpers['init'] += add_init - script_helpers['ready'] += add_ready + script_helpers["advanced"] += add_import + script_helpers["init"] += add_init + script_helpers["ready"] += add_ready html = get_xgi_html_header( configuration, - '%(short_title)s Monitor, VGrid %(vgrid_name)s' % html_vars, - '', + "%(short_title)s Monitor, VGrid %(vgrid_name)s" % html_vars, + "", html=True, meta=monitor_meta, style_map=style_helpers, @@ -130,15 +136,13 @@ def create_monitor(vgrid_name): widgets=False, userstyle=False, ) - html += \ - ''' + html += """

Statistics/monitor for the %(vgrid_name)s VGrid

This page was generated %(now)s (automatic refresh every %(sleep_secs)s secs).
-'''\
-        % html_vars
+""" % html_vars
 
     # loop and get totals
 
@@ -162,7 +166,7 @@
     disk_done = 0
     memory_requested = 0
     memory_done = 0
-    runtimeenv_dict = {'': 0}
+    runtimeenv_dict = {"": 0}
     runtimeenv_requested = 0
     runtimeenv_done = 0
 
@@ -176,62 +180,70 @@
     gstat = GridStat(configuration, logger)
 
-    runtimeenv_dict = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                      'RUNTIMEENVIRONMENT', {})
-
-    parse_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                  'PARSE')
-    queued_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                   'QUEUED')
-    frozen_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                   'FROZEN')
-    executing_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                      'EXECUTING')
-    failed_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                   'FAILED')
-    retry_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                  'RETRY')
-    canceled_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                     'CANCELED')
-    expired_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                    'EXPIRED')
-    finished_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                     'FINISHED')
-
-    nodecount_requested = gstat.get_value(gstat.VGRID,
-                                          vgrid_name.upper(), 'NODECOUNT_REQ')
-    nodecount_done = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                     'NODECOUNT_DONE')
-    cputime_requested = gstat.get_value(gstat.VGRID,
-                                        vgrid_name.upper(), 'CPUTIME_REQ')
-    cputime_done = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                   'CPUTIME_DONE')
-
-    used_walltime = gstat.get_value(gstat.VGRID,
-                                    vgrid_name.upper(),
-                                    'USED_WALLTIME')
-
-    if (used_walltime == 0):
+    runtimeenv_dict = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "RUNTIMEENVIRONMENT", {}
+    )
+
+    parse_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(), "PARSE")
+    queued_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(), "QUEUED")
+    frozen_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(), "FROZEN")
+    executing_count = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "EXECUTING"
+    )
+    failed_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(), "FAILED")
+    retry_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(), "RETRY")
+    canceled_count = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "CANCELED"
+    )
+    expired_count = gstat.get_value(gstat.VGRID, vgrid_name.upper(), "EXPIRED")
+    finished_count = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "FINISHED"
+    )
+
+    nodecount_requested = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "NODECOUNT_REQ"
+    )
+    nodecount_done = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "NODECOUNT_DONE"
+    )
+    cputime_requested = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "CPUTIME_REQ"
+    )
+    cputime_done = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "CPUTIME_DONE"
+    )
+
+    used_walltime = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "USED_WALLTIME"
+    )
+
+    if used_walltime == 0:
         used_walltime = datetime.timedelta(0)
 
     used_walltime = format_timedelta(used_walltime)
 
-    disk_requested = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                     'DISK_REQ')
-    disk_done = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                'DISK_DONE')
-    memory_requested = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                       'MEMORY_REQ')
-    memory_done = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                  'MEMORY_DONE')
-    cpucount_requested = gstat.get_value(gstat.VGRID,
-                                         vgrid_name.upper(), 'CPUCOUNT_REQ')
-    cpucount_done = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                    'CPUCOUNT_DONE')
-    runtimeenv_requested = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                           'RUNTIMEENVIRONMENT_REQ')
-    runtimeenv_done = gstat.get_value(gstat.VGRID, vgrid_name.upper(),
-                                      'RUNTIMEENVIRONMENT_DONE')
+    disk_requested = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "DISK_REQ"
+    )
+    disk_done = gstat.get_value(gstat.VGRID, vgrid_name.upper(), "DISK_DONE")
+    memory_requested = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "MEMORY_REQ"
+    )
+    memory_done = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "MEMORY_DONE"
+    )
+    cpucount_requested = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "CPUCOUNT_REQ"
+    )
+    cpucount_done = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "CPUCOUNT_DONE"
+    )
+    runtimeenv_requested = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "RUNTIMEENVIRONMENT_REQ"
+    )
+    runtimeenv_done = gstat.get_value(
+        gstat.VGRID, vgrid_name.upper(), "RUNTIMEENVIRONMENT_DONE"
+    )
 
     number_of_jobs = parse_count
     number_of_jobs += queued_count
@@ -244,33 +256,32 @@ def create_monitor(vgrid_name):
     number_of_jobs += retry_count
 
     html_vars = {
-        'parse_count': parse_count,
-        'queued_count': queued_count,
-        'frozen_count': frozen_count,
-        'executing_count': executing_count,
-        'failed_count': failed_count,
-        'retry_count': retry_count,
-        'canceled_count': canceled_count,
-        'expired_count': expired_count,
-        'finished_count': finished_count,
-        'number_of_jobs': number_of_jobs,
-        'cpucount_requested': cpucount_requested,
-        'cpucount_done': cpucount_done,
-        'nodecount_requested': nodecount_requested,
-        'nodecount_done': nodecount_done,
-        'cputime_requested': cputime_requested,
-        'cputime_done': cputime_done,
-        'used_walltime': used_walltime,
-        'disk_requested': disk_requested,
-        'disk_done': disk_done,
-        'memory_requested': memory_requested,
-        'memory_done': memory_done,
-        'runtimeenv_requested': runtimeenv_requested,
-        'runtimeenv_done': runtimeenv_done,
+        "parse_count": parse_count,
+        "queued_count": queued_count,
+        "frozen_count": frozen_count,
+        "executing_count": executing_count,
+        "failed_count": failed_count,
+        "retry_count": retry_count,
+        "canceled_count": canceled_count,
+        "expired_count": expired_count,
+        "finished_count": finished_count,
+        "number_of_jobs": number_of_jobs,
+        "cpucount_requested": cpucount_requested,
+        "cpucount_done": cpucount_done,
+        "nodecount_requested": nodecount_requested,
+        "nodecount_done": nodecount_done,
+        "cputime_requested": cputime_requested,
+        "cputime_done": cputime_done,
+        "used_walltime": used_walltime,
+        "disk_requested": disk_requested,
+        "disk_done": disk_done,
+        "memory_requested": memory_requested,
+        "memory_done": memory_done,
+        "runtimeenv_requested": runtimeenv_requested,
+        "runtimeenv_done": runtimeenv_done,
     }
-    html += \
-        """

Job Stats

""" html += stores - html += '\n
+ html += """

Job Stats

\n
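
Side note, not part of the patch: the template blocks above are all filled with Python's named %-substitution against a dict such as html_vars, which is why black keeps the "% html_vars" tail attached to the line with the closing quotes. A minimal, self-contained sketch of the pattern, with made-up values:

    # Named-placeholder substitution as used by the grid_monitor.py templates.
    # Every %(key)s in the template must have a matching dict key, otherwise
    # the substitution raises KeyError.
    html_vars = {
        "vgrid_name": "Generic",
        "sleep_secs": 60,
        "now": "Mon Jan  1 12:00:00 2024",
    }
    page = """
    Statistics/monitor for the %(vgrid_name)s VGrid
    This page was generated %(now)s (automatic refresh every %(sleep_secs)s secs).
    """ % html_vars
    print(page)
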
@@ -304,20 +315,22 @@ def create_monitor(vgrid_name): # No runtimeenv requests - html += '\n' + html += "\n" else: for entry in runtimeenv_dict: - if not entry == '': - html += '\n' % \ - (entry, runtimeenv_dict[entry]) + if not entry == "": + html += "\n" % ( + entry, + runtimeenv_dict[entry], + ) total_number_of_exe_resources, total_number_of_store_resources = 0, 0 total_number_of_exe_cpus, total_number_of_store_gigs = 0, 0 - vgrid_name_list = vgrid_name.split('/') - current_dir = '' + vgrid_name_list = vgrid_name.split("/") + current_dir = "" - exes, stores = '', '' + exes, stores = "", "" for vgrid_name_part in vgrid_name_list: current_dir = os.path.join(current_dir, vgrid_name_part) abs_mon_dir = os.path.join(configuration.vgrid_home, current_dir) @@ -330,38 +343,39 @@ def create_monitor(vgrid_name): sorted_names.sort() for filename in sorted_names: # print filename - if filename.startswith('monitor_last_request_'): + if filename.startswith("monitor_last_request_"): # read last request helper file mon_file_name = os.path.join(abs_mon_dir, filename) - print('found ' + mon_file_name) + print("found " + mon_file_name) last_request_dict = unpickle(mon_file_name, logger) if not last_request_dict: - print('could not open and unpickle: ' - + mon_file_name) + print("could not open and unpickle: " + mon_file_name) continue - if 'CREATED_TIME' not in last_request_dict: - print('skip broken last request dict: ' - + mon_file_name) + if "CREATED_TIME" not in last_request_dict: + print("skip broken last request dict: " + mon_file_name) continue - difference = datetime.datetime.now()\ - - last_request_dict['CREATED_TIME'] + difference = ( + datetime.datetime.now() - last_request_dict["CREATED_TIME"] + ) days = "%d" % (difference.days) hours = "%d" % (difference.seconds // 3600) minutes = "%d" % ((difference.seconds % 3600) // 60) seconds = "%d" % ((difference.seconds % 60) % 60) - last_timetuple = last_request_dict['CREATED_TIME'].timetuple() + last_timetuple = last_request_dict["CREATED_TIME"].timetuple() - if 'CPUTIME' in last_request_dict: - cputime = last_request_dict['CPUTIME'] - elif 'cputime' in last_request_dict: - cputime = last_request_dict['cputime'] + if "CPUTIME" in last_request_dict: + cputime = last_request_dict["CPUTIME"] + elif "cputime" in last_request_dict: + cputime = last_request_dict["cputime"] else: - print('ERROR: last request does not contain cputime field!: %s' - % last_request_dict) + print( + "ERROR: last request does not contain cputime field!: %s" + % last_request_dict + ) continue try: @@ -370,23 +384,26 @@ def create_monitor(vgrid_name): try: cpusec = int(float(cputime)) except ValueError as verr: - print('ERROR: failed to parse cputime %s: %s' - % (cputime, verr)) + print( + "ERROR: failed to parse cputime %s: %s" + % (cputime, verr) + ) # Include execution delay guesstimate for strict fill # LRMS resources try: - delay = int(last_request_dict['EXECUTION_DELAY']) + delay = int(last_request_dict["EXECUTION_DELAY"]) except KeyError: delay = 0 except ValueError: delay = 0 - time_remaining = (last_request_dict['CREATED_TIME'] - + datetime.timedelta(seconds=cpusec) - + datetime.timedelta(seconds=delay))\ - - datetime.datetime.now() + time_remaining = ( + last_request_dict["CREATED_TIME"] + + datetime.timedelta(seconds=cpusec) + + datetime.timedelta(seconds=delay) + ) - datetime.datetime.now() days_rem = "%d" % (time_remaining.days) hours_rem = "%d" % (time_remaining.seconds // 3600) minutes_rem = "%d" % ((time_remaining.seconds % 3600) // 60) @@ -394,97 +411,137 @@ def 
create_monitor(vgrid_name): if time_remaining.days < -7: try: - print('removing: %s as we havent seen him for %s days.' - % (mon_file_name, abs(time_remaining).days)) + print( + "removing: %s as we havent seen him for %s days." + % (mon_file_name, abs(time_remaining).days) + ) os.remove(mon_file_name) except Exception as err: - print("could not remove: '%s' Error: %s" - % (mon_file_name, err)) + print( + "could not remove: '%s' Error: %s" + % (mon_file_name, err) + ) pass else: - unique_res_name_and_exe_list = \ - filename.split('monitor_last_request_', 1) + unique_res_name_and_exe_list = filename.split( + "monitor_last_request_", 1 + ) if cpusec == 0: - resource_status = 'unavailable' + resource_status = "unavailable" elif time_remaining.days < 0: # time_remaining.days < 0 means that we have passed the specified time time_rem_abs = abs(time_remaining) - if time_rem_abs.days == 0\ - and int(time_rem_abs.seconds)\ - < int(slackperiod): - resource_status = 'slack' + if time_rem_abs.days == 0 and int( + time_rem_abs.seconds + ) < int(slackperiod): + resource_status = "slack" slack_count = slack_count + 1 else: - resource_status = 'offline' + resource_status = "offline" down_count = down_count + 1 else: - resource_status = 'online' + resource_status = "online" up_count = up_count + 1 - exes += '' - exes += \ - ''\ + exes += "" + exes += ( + "" % resource_status + ) public_id = unique_res_name_and_exe_list[1] - if last_request_dict['RESOURCE_CONFIG'].get('ANONYMOUS', True): + if last_request_dict["RESOURCE_CONFIG"].get( + "ANONYMOUS", True + ): public_id = anon_resource_id(public_id) - public_name = last_request_dict['RESOURCE_CONFIG'].get( - 'PUBLICNAME', '') - resource_parts = public_id.split('_', 2) - resource_name = "%s" % \ - (resource_parts[0], resource_parts[0]) + public_name = last_request_dict["RESOURCE_CONFIG"].get( + "PUBLICNAME", "" + ) + resource_parts = public_id.split("_", 2) + resource_name = ( + "%s" + % (resource_parts[0], resource_parts[0]) + ) if public_name: resource_name += "
(alias %s)" % public_name else: resource_name += "
(no alias)" resource_name += "
%s" % resource_parts[1] - exes += '' % resource_name + exes += "" % resource_name last_asctime = time.asctime(last_timetuple) last_epoch = time.mktime(last_timetuple) - exes += '' % (days, hours, minutes, - seconds) - exes += '' - runtime_envs = last_request_dict['RESOURCE_CONFIG' - ]['RUNTIMEENVIRONMENT'] + exes += '" % ( + days, + hours, + minutes, + seconds, + ) + exes += "" + runtime_envs = last_request_dict["RESOURCE_CONFIG"][ + "RUNTIMEENVIRONMENT" + ] runtime_envs.sort() - re_list_text = ', '.join([i[0] for i in runtime_envs]) + re_list_text = ", ".join([i[0] for i in runtime_envs]) exes += '' % ( - re_list_text, len(runtime_envs)) - for req_name in ['CPUTIME', 'NODECOUNT', 'CPUCOUNT', - 'DISK', 'MEMORY', 'ARCHITECTURE']: - exes += '' % last_request_dict['RESOURCE_CONFIG'][req_name] - exes += '' % last_request_dict - - exes += '" + % last_request_dict["RESOURCE_CONFIG"][req_name] + ) + exes += ( + "" + % last_request_dict + ) + + exes += "' - - exes += '\n' - if last_request_dict['STATUS'] == 'Job assigned': + exes += "%sd, %sh, %sm, %ss" % ( + days_rem, + hours_rem, + minutes_rem, + seconds_rem, + ) + exes += "" + + exes += "\n" + if last_request_dict["STATUS"] == "Job assigned": job_assigned += 1 - job_assigned_cpus += int(last_request_dict['RESOURCE_CONFIG']['NODECOUNT']) * int( - last_request_dict['RESOURCE_CONFIG']['CPUCOUNT']) + job_assigned_cpus += int( + last_request_dict["RESOURCE_CONFIG"]["NODECOUNT"] + ) * int( + last_request_dict["RESOURCE_CONFIG"]["CPUCOUNT"] + ) total_number_of_exe_resources += 1 total_number_of_exe_cpus += int( - last_request_dict['RESOURCE_CONFIG']['NODECOUNT']) \ - * int(last_request_dict['RESOURCE_CONFIG']['CPUCOUNT']) - elif filename.startswith('monitor_last_status_'): + last_request_dict["RESOURCE_CONFIG"]["NODECOUNT"] + ) * int(last_request_dict["RESOURCE_CONFIG"]["CPUCOUNT"]) + elif filename.startswith("monitor_last_status_"): # store must be linked to this vgrid, not only parent vgrid: # inheritance only covers access, not automatic participation @@ -495,58 +552,65 @@ def create_monitor(vgrid_name): # read last resource action status file mon_file_name = os.path.join(abs_mon_dir, filename) - print('found ' + mon_file_name) + print("found " + mon_file_name) last_status_dict = unpickle(mon_file_name, logger) if not last_status_dict: - print('could not open and unpickle: ' - + mon_file_name) + print("could not open and unpickle: " + mon_file_name) continue - if 'CREATED_TIME' not in last_status_dict: - print('skip broken last request dict: ' - + mon_file_name) + if "CREATED_TIME" not in last_status_dict: + print("skip broken last request dict: " + mon_file_name) continue - difference = datetime.datetime.now()\ - - last_status_dict['CREATED_TIME'] + difference = ( + datetime.datetime.now() - last_status_dict["CREATED_TIME"] + ) days = "%d" % (difference.days) hours = "%d" % (difference.seconds // 3600) minutes = "%d" % ((difference.seconds % 3600) // 60) seconds = "%d" % ((difference.seconds % 60) % 60) - if last_status_dict['STATUS'] == 'stopped': - time_stopped = datetime.datetime.now() - \ - last_status_dict['CREATED_TIME'] + if last_status_dict["STATUS"] == "stopped": + time_stopped = ( + datetime.datetime.now() + - last_status_dict["CREATED_TIME"] + ) if time_stopped.days > 7: try: - print('removing: %s as we havent seen him for %s days.' - % (mon_file_name, abs(time_stopped).days)) + print( + "removing: %s as we havent seen him for %s days." 
+ % (mon_file_name, abs(time_stopped).days) + ) os.remove(mon_file_name) except Exception as err: - print("could not remove: '%s' Error: %s" - % (mon_file_name, err)) + print( + "could not remove: '%s' Error: %s" + % (mon_file_name, err) + ) continue unique_res_name_and_store_list = filename.split( - 'monitor_last_status_', 1) - mount_point = last_status_dict.get('MOUNT_POINT', 'UNKNOWN') + "monitor_last_status_", 1 + ) + mount_point = last_status_dict.get("MOUNT_POINT", "UNKNOWN") is_live = os.path.ismount(mount_point) public_id = unique_res_name_and_store_list[1] - if last_status_dict['RESOURCE_CONFIG'].get('ANONYMOUS', True): + if last_status_dict["RESOURCE_CONFIG"].get("ANONYMOUS", True): public_id = anon_resource_id(public_id) vgrid_link = os.path.join( - configuration.vgrid_files_home, vgrid_name, public_id) - is_linked = (os.path.realpath(vgrid_link) == mount_point) + configuration.vgrid_files_home, vgrid_name, public_id + ) + is_linked = os.path.realpath(vgrid_link) == mount_point - total_disk = last_status_dict['RESOURCE_CONFIG']['DISK'] + total_disk = last_status_dict["RESOURCE_CONFIG"]["DISK"] free_disk, avail_disk, used_disk, used_percent = 0, 0, 0, 0 gig_bytes = 1.0 * 2**30 # Fall back status - show last action unless statvfs succeeds - last_status = last_status_dict['STATUS'] - last_timetuple = last_status_dict['CREATED_TIME'].timetuple() + last_status = last_status_dict["STATUS"] + last_timetuple = last_status_dict["CREATED_TIME"].timetuple() # These disk stats are slightly confusing but match 'df' # 'available' is the space that can actually be used so it @@ -554,66 +618,81 @@ def create_monitor(vgrid_name): try: disk_stats = os.statvfs(mount_point) - total_disk = disk_stats.f_bsize * disk_stats.f_blocks // \ - gig_bytes - avail_disk = disk_stats.f_bsize * disk_stats.f_bavail // \ - gig_bytes - free_disk = disk_stats.f_bsize * disk_stats.f_bfree // \ - gig_bytes + total_disk = ( + disk_stats.f_bsize * disk_stats.f_blocks // gig_bytes + ) + avail_disk = ( + disk_stats.f_bsize * disk_stats.f_bavail // gig_bytes + ) + free_disk = ( + disk_stats.f_bsize * disk_stats.f_bfree // gig_bytes + ) used_disk = total_disk - free_disk used_percent = used_disk * 100.0 / (avail_disk + used_disk) - last_status = 'checked' + last_status = "checked" last_timetuple = datetime.datetime.now().timetuple() days, hours, minutes, seconds = 0, 0, 0, 0 except OSError as ose: - print('could not stat mount point %s: %s' % - (mount_point, ose)) + print( + "could not stat mount point %s: %s" % (mount_point, ose) + ) is_live = False - if last_status_dict['STATUS'] == 'stopped': - resource_status = 'offline' + if last_status_dict["STATUS"] == "stopped": + resource_status = "offline" down_count = down_count + 1 - elif last_status_dict['STATUS'] == 'started': + elif last_status_dict["STATUS"] == "started": if is_live and is_linked: - resource_status = 'online' + resource_status = "online" up_count = up_count + 1 else: - resource_status = 'slack' + resource_status = "slack" down_count = down_count + 1 else: - resource_status = 'unknown' + resource_status = "unknown" - stores += '' - stores += \ - ''\ + stores += "" + stores += ( + "" % resource_status - public_name = last_status_dict['RESOURCE_CONFIG'].get( - 'PUBLICNAME', '') - resource_parts = public_id.split('_', 2) - resource_name = "%s" % \ - (resource_parts[0], resource_parts[0]) + ) + public_name = last_status_dict["RESOURCE_CONFIG"].get( + "PUBLICNAME", "" + ) + resource_parts = public_id.split("_", 2) + resource_name = ( + "%s" + % (resource_parts[0], 
resource_parts[0]) + ) if public_name: resource_name += "
(alias %s)" % public_name else: resource_name += "
(no alias)" resource_name += "
%s" % resource_parts[1] - stores += '' % resource_name + stores += "" % resource_name last_asctime = time.asctime(last_timetuple) last_epoch = time.mktime(last_timetuple) - stores += '' % (days, hours, minutes, - seconds) - stores += '' - stores += '' % total_disk - stores += '' % used_disk - stores += '' % avail_disk - stores += '' % used_percent - - stores += '' - - stores += '\n' + stores += '" % ( + days, + hours, + minutes, + seconds, + ) + stores += "" + stores += "" % total_disk + stores += "" % used_disk + stores += "" % avail_disk + stores += "" % used_percent + + stores += "" + + stores += "\n" total_number_of_store_resources += 1 total_number_of_store_gigs += total_disk @@ -647,7 +726,7 @@ def create_monitor(vgrid_name): """ html += exes - html += '\n
<tr><th>Job State</th><th>Number of jobs</th></tr>
<tr><td>Parse</td><td>%(parse_count)s</td></tr>
<tr><td>Queued</td><td>%(queued_count)s</td></tr>
-
-
%s%s
%s%s
%s%s
%s
%s
' % \ - (last_epoch, last_asctime) - exes += '(%sd %sh %sm %ss ago)
' + vgrid_name + '
%s
%s
' % ( + last_epoch, + last_asctime, + ) + exes += "(%sd %sh %sm %ss ago)
" + vgrid_name + "%d%s%(STATUS)s%(CPUTIME)s' % resource_status - if 'unavailable' == resource_status: - exes += '-' - elif 'slack' == resource_status: - exes += 'Within slack period (%s < %s secs)'\ - % (time_rem_abs.seconds, slackperiod) - elif 'offline' == resource_status: - exes += 'down?' + re_list_text, + len(runtime_envs), + ) + for req_name in [ + "CPUTIME", + "NODECOUNT", + "CPUCOUNT", + "DISK", + "MEMORY", + "ARCHITECTURE", + ]: + exes += ( + "%s%(STATUS)s%(CPUTIME)s" % resource_status + if "unavailable" == resource_status: + exes += "-" + elif "slack" == resource_status: + exes += "Within slack period (%s < %s secs)" % ( + time_rem_abs.seconds, + slackperiod, + ) + elif "offline" == resource_status: + exes += "down?" else: - exes += '%sd, %sh, %sm, %ss'\ - % (days_rem, hours_rem, minutes_rem, - seconds_rem) - exes += '
%s%s
%s
%s %s
' % \ - (last_epoch, last_status, last_asctime) - stores += '(%sd %sh %sm %ss ago)
' + vgrid_name + '%d%d%d%d' % resource_status - stores += resource_status + '
%s
%s %s
' % ( + last_epoch, + last_status, + last_asctime, + ) + stores += "(%sd %sh %sm %ss ago)
" + vgrid_name + "%d%d%d%d" % resource_status + stores += resource_status + "
\n' + html += "
\n" html += """

Resource Storage

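
Side note, not part of the patch: the store-resource rows built above get their numbers from os.statvfs() on each mount point. A rough, self-contained sketch of that df-style accounting (the /tmp path is just an example):

    import os

    # df-style disk stats as in create_monitor(): f_bavail excludes the root
    # reserve, so used + avail can be less than total, and used_percent
    # follows df's convention of dividing by (avail + used).
    stats = os.statvfs("/tmp")
    gig_bytes = 1.0 * 2**30
    total_disk = stats.f_bsize * stats.f_blocks // gig_bytes
    avail_disk = stats.f_bsize * stats.f_bavail // gig_bytes
    free_disk = stats.f_bsize * stats.f_bfree // gig_bytes
    used_disk = total_disk - free_disk
    used_percent = used_disk * 100.0 / (avail_disk + used_disk)
    print("%dG used of %dG (%.1f%%)" % (used_disk, total_disk, used_percent))
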
@@ -670,44 +749,46 @@ def create_monitor(vgrid_name):
\n'
-
-    fill_helpers = {'total_number_of_exe_resources': total_number_of_exe_resources,
-                    'total_number_of_exe_cpus': total_number_of_exe_cpus,
-                    'total_number_of_store_resources': total_number_of_store_resources,
-                    'total_number_of_store_gigs': int(total_number_of_store_gigs),
-                    'up_count': up_count, 'down_count': down_count,
-                    'slack_count': slack_count, 'job_assigned': job_assigned,
-                    'job_assigned_cpus': job_assigned_cpus}
+    html += "\n\n"
+
+    fill_helpers = {
+        "total_number_of_exe_resources": total_number_of_exe_resources,
+        "total_number_of_exe_cpus": total_number_of_exe_cpus,
+        "total_number_of_store_resources": total_number_of_store_resources,
+        "total_number_of_store_gigs": int(total_number_of_store_gigs),
+        "up_count": up_count,
+        "down_count": down_count,
+        "slack_count": slack_count,
+        "job_assigned": job_assigned,
+        "job_assigned_cpus": job_assigned_cpus,
+    }
 
     html += """

VGrid Totals

A total of %(total_number_of_exe_resources)s exe resources (%(total_number_of_exe_cpus)s cpu's) and %(total_number_of_store_resources)s store resources (%(total_number_of_store_gigs)s GB) joined this VGrid (%(up_count)s up, %(down_count)s down?, %(slack_count)s slack)
%(job_assigned)s exe resources (%(job_assigned_cpus)s cpu's) appear to be executing a job

""" % fill_helpers - html += \ - '' - html += get_xgi_html_footer(configuration, '') + html += "" + html += get_xgi_html_footer(configuration, "") try: - file_handle = open(html_file, 'w') + file_handle = open(html_file, "w") file_handle.write(html) file_handle.close() except Exception as exc: - print('Could not write monitor page %s: %s' % (html_file, exc)) + print("Could not write monitor page %s: %s" % (html_file, exc)) -if __name__ == '__main__': +if __name__ == "__main__": # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] # Use separate logger - logger = daemon_logger("monitor", configuration.user_monitor_log, - log_level) + logger = daemon_logger("monitor", configuration.user_monitor_log, log_level) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates @@ -733,7 +814,7 @@ def create_monitor(vgrid_name): try: os.makedirs(default_vgrid_dir) except OSError as ose: - logger.error('Failed to create default VGrid home: %s' % ose) + logger.error("Failed to create default VGrid home: %s" % ose) keep_running = True while keep_running: @@ -743,20 +824,20 @@ def create_monitor(vgrid_name): # create global statistics ("") # vgrids_list.append("") - print('Updating cache.') + print("Updating cache.") grid_stat = GridStat(configuration, logger) grid_stat.update() for vgrid_name in vgrids_list: - print('creating monitor for vgrid: %s' % vgrid_name) + print("creating monitor for vgrid: %s" % vgrid_name) create_monitor(vgrid_name) - print('sleeping for %s seconds' % configuration.sleep_secs) + print("sleeping for %s seconds" % configuration.sleep_secs) time.sleep(float(configuration.sleep_secs)) except KeyboardInterrupt: keep_running = False except Exception as exc: - print('Caught unexpected exception: %s' % exc) + print("Caught unexpected exception: %s" % exc) time.sleep(10) - print('Monitor daemon shutting down') + print("Monitor daemon shutting down") sys.exit(0) diff --git a/mig/server/grid_notify.py b/mig/server/grid_notify.py index b74c114ef..780a4da18 100755 --- a/mig/server/grid_notify.py +++ b/mig/server/grid_notify.py @@ -29,26 +29,23 @@ flooding. 
""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function -from builtins import range -import os import multiprocessing +import os import signal import sys import time +from builtins import range from datetime import datetime -from mig.shared.base import extract_field, expand_openid_alias +from mig.shared.base import expand_openid_alias, extract_field from mig.shared.conf import get_configuration_object from mig.shared.defaults import ignore_file_names -from mig.shared.fileio import unpickle, delete_file -from mig.shared.logger import daemon_logger, \ - register_hangup_handler +from mig.shared.fileio import delete_file, unpickle +from mig.shared.logger import daemon_logger, register_hangup_handler from mig.shared.notification import send_email - stop_running = multiprocessing.Event() notify_interval = 60 received_notifications = {} @@ -57,7 +54,7 @@ def stop_handler(sig, frame): """A simple signal handler to quit on Ctrl+C (SIGINT) in main""" # Print blank line to avoid mix with Ctrl-C line - print('') + print("") stop_running.set() @@ -72,12 +69,14 @@ def cleanup_notify_home(configuration, notified_users=[], timestamp=None): # Remove notification files for notified users for client_id in notified_users: - cleanup_files = received_notifications.get( - client_id, {}).get('files', []) + cleanup_files = received_notifications.get(client_id, {}).get( + "files", [] + ) if not cleanup_files: logger.error( "Expected _NON_ empty files list for client_id: '%s'" - % client_id) + % client_id + ) for filepath in cleanup_files: # logger.debug("Removing notification file: '%s'" % filepath) delete_file(filepath, logger) @@ -106,43 +105,45 @@ def send_notifications(configuration): logger = configuration.logger # logger.debug("send_notifications") result = [] - for (client_id, client_dict) in received_notifications.items(): - timestamp = client_dict.get('timestamp', 0) - - timestr = (datetime.fromtimestamp(timestamp) - ).strftime('%d/%m/%Y %H:%M:%S') - client_name = extract_field(client_id, 'full_name') - client_email = extract_field(client_id, 'email') - recipient = "%s <%s>" % (client_name, - client_email) + for client_id, client_dict in received_notifications.items(): + timestamp = client_dict.get("timestamp", 0) + + timestr = (datetime.fromtimestamp(timestamp)).strftime( + "%d/%m/%Y %H:%M:%S" + ) + client_name = extract_field(client_id, "full_name") + client_email = extract_field(client_id, "email") + recipient = "%s <%s>" % (client_name, client_email) total_events = 0 notify_message = "" - messages_dict = client_dict.get('messages', {}) - for (header, value) in messages_dict.items(): + messages_dict = client_dict.get("messages", {}) + for header, value in messages_dict.items(): if notify_message: notify_message += "\n\n" notify_message += "= %s =\n" % header - for (message, events) in value.items(): + for message, events in value.items(): notify_message += "#%s : %s\n" % (events, message) total_events += events - subject = "%s system notification: %s new events" % \ - (configuration.short_title, total_events) - notify_message = "Found %s new events since: %s\n\n" \ - % (total_events, timestr) \ + subject = "%s system notification: %s new events" % ( + configuration.short_title, + total_events, + ) + notify_message = ( + "Found %s new events since: %s\n\n" % (total_events, timestr) + notify_message + ) status = send_email( - recipient, - subject, - notify_message, - logger, - configuration) + recipient, subject, 
notify_message, logger, configuration + ) if status: - logger.info("Send email with %s events to: %s" - % (total_events, recipient)) + logger.info( + "Send email with %s events to: %s" % (total_events, recipient) + ) result.append(client_id) else: - logger.error("Failed to send email to: '%s', '%s'" % - (recipient, client_id)) + logger.error( + "Failed to send email to: '%s', '%s'" % (recipient, client_id) + ) return result @@ -156,7 +157,7 @@ def recv_notification(configuration, path): if not new_notification: logger.error("Failed to unpickle: %s" % path) return False - user_id = new_notification.get('user_id', '') + user_id = new_notification.get("user_id", "") # logger.debug("Received user_id: '%s'" % user_id) if not user_id: status = False @@ -164,42 +165,43 @@ def recv_notification(configuration, path): else: client_id = expand_openid_alias(user_id, configuration) # logger.debug("resolved client_id: '%s'" % client_id) - if not client_id or not extract_field(client_id, 'email'): + if not client_id or not extract_field(client_id, "email"): status = False - logger.error("Failed to resolve client_id from user_id: '%s'" - % user_id) + logger.error( + "Failed to resolve client_id from user_id: '%s'" % user_id + ) if status: - category = new_notification.get('category', []) + category = new_notification.get("category", []) # logger.debug("Received category: %s" % category) if not isinstance(category, list): status = False logger.error("Received category: %s must be a list" % category) if status: - logger.info("Received event: %s, from: '%s'" - % (category, client_id)) - new_timestamp = new_notification.get('timestamp') - message = new_notification.get('message', '') + logger.info("Received event: %s, from: '%s'" % (category, client_id)) + new_timestamp = new_notification.get("timestamp") + message = new_notification.get("message", "") # logger.debug("Received message: %s" % message) client_dict = received_notifications.get(client_id, {}) if not client_dict: received_notifications[client_id] = client_dict - files_list = client_dict.get('files', []) + files_list = client_dict.get("files", []) if not files_list: - client_dict['files'] = files_list + client_dict["files"] = files_list if path in files_list: logger.warning( - "Skipping previously received notification: '%s'" % path) + "Skipping previously received notification: '%s'" % path + ) else: files_list.append(path) - client_dict['timestamp'] = min( - client_dict.get('timestamp', sys.maxsize), - new_timestamp) - messages_dict = client_dict.get('messages', {}) + client_dict["timestamp"] = min( + client_dict.get("timestamp", sys.maxsize), new_timestamp + ) + messages_dict = client_dict.get("messages", {}) if not messages_dict: - client_dict['messages'] = messages_dict + client_dict["messages"] = messages_dict header = " ".join(category) if not header: - header = '* UNKNOWN *' + header = "* UNKNOWN *" body_dict = messages_dict.get(header, {}) if not body_dict: messages_dict[header] = body_dict @@ -226,9 +228,11 @@ def handle_notifications(configuration): recv_notification(configuration, abspath) notified_users = send_notifications(configuration) last_notification = time.time() - cleanup_notify_home(configuration, - notified_users=notified_users, - timestamp=last_notification - 84600) + cleanup_notify_home( + configuration, + notified_users=notified_users, + timestamp=last_notification - 84600, + ) received_notifications.clear() logger.debug("----- Sleeping %s seconds -----" % notify_interval) time.sleep(notify_interval) @@ -247,8 +251,10 @@ 
def unittest(configuration, emailaddr, delay): """Unit test for grid_notify.py""" signal.signal(signal.SIGINT, stop_handler) from mig.shared.notification import send_system_notification - print("Starting unittest: emailaddr: %s" % emailaddr - + ", delay: %s" % delay) + + print( + "Starting unittest: emailaddr: %s" % emailaddr + ", delay: %s" % delay + ) if delay > 0: print("Waiting %s secs before executing unit test" % delay) time.sleep(delay) @@ -259,34 +265,37 @@ def unittest(configuration, emailaddr, delay): for i in range(nr_debug_users): client_ids.append( "/C=DK/ST=NA/L=NA/O=NBI/OU=NA/CN=Grid Notify %i/emailAddress=%s" - % (i, emailaddr)) + % (i, emailaddr) + ) print("=============================") print("======= Starting test =======") print("=============================") for client_id in client_ids: for i in range(5): - for protocol in ['SFTP', 'WebDAVS']: + for protocol in ["SFTP", "WebDAVS"]: if stop_running.is_set(): return category = [protocol] msg = "__UNITTEST__: %s" % protocol - print("unittest: Sending notification: %s" - ", category: %s: %s" % (i, category, client_id)) - send_system_notification(client_id, - category, - msg, - configuration) - for event in ['Invalid password', 'Expired 2FA session']: + print( + "unittest: Sending notification: %s" + ", category: %s: %s" % (i, category, client_id) + ) + send_system_notification( + client_id, category, msg, configuration + ) + for event in ["Invalid password", "Expired 2FA session"]: if stop_running.is_set(): return category = [protocol, event] msg = "__UNITTEST__: %s" % client_id - print("unittest: Sending notification: %s" % i - + ", category: %s: %s" % (category, client_id)) - send_system_notification(client_id, - category, - msg, - configuration) + print( + "unittest: Sending notification: %s" % i + + ", category: %s: %s" % (category, client_id) + ) + send_system_notification( + client_id, category, msg, configuration + ) if __name__ == "__main__": @@ -297,11 +306,15 @@ def unittest(configuration, emailaddr, delay): emailaddr = None delay = 0 argpos = 1 - if sys.argv[argpos:] and sys.argv[argpos] \ - in ['debug', 'info', 'warning', 'error']: + if sys.argv[argpos:] and sys.argv[argpos] in [ + "debug", + "info", + "warning", + "error", + ]: log_level = sys.argv[argpos] argpos += 1 - if sys.argv[argpos:] and len(sys.argv[argpos].split('@')) == 2: + if sys.argv[argpos:] and len(sys.argv[argpos].split("@")) == 2: emailaddr = sys.argv[argpos] argpos += 1 if sys.argv[argpos:]: @@ -319,13 +332,14 @@ def unittest(configuration, emailaddr, delay): # Start unittest if requested if emailaddr: - unittest_proc = multiprocessing.Process(target=unittest, - args=(configuration, - emailaddr, - delay)) + unittest_proc = multiprocessing.Process( + target=unittest, args=(configuration, emailaddr, delay) + ) unittest_proc.start() - info_msg = "Starting unit test process: email: %s, delay: %s" \ - % (emailaddr, delay) + info_msg = "Starting unit test process: email: %s, delay: %s" % ( + emailaddr, + delay, + ) print(info_msg) logger.info("(%s) %s" % (unittest_proc.pid, info_msg)) @@ -341,21 +355,22 @@ def unittest(configuration, emailaddr, delay): print(err_msg) sys.exit(1) - print('''This is the MiG system notify daemon which notify users about system events. + print( + """This is the MiG system notify daemon which notify users about system events. 
Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -''') +""" + ) main_pid = os.getpid() print("Starting notify daemon - Ctrl-C to quit") logger.info("(%s) Starting notify daemon" % main_pid) - (exit_code, exit_msg) = handle_notifications(configuration) + exit_code, exit_msg = handle_notifications(configuration) stop_msg = "Stopping notify daemon" if exit_code == 0: print(stop_msg) - logger.info("(%s) %s" - % (main_pid, stop_msg)) + logger.info("(%s) %s" % (main_pid, stop_msg)) else: stop_msg += ", exit_code: %s, %s" % (exit_code, exit_msg) print(stop_msg) diff --git a/mig/server/grid_script.py b/mig/server/grid_script.py index fb403f7f7..37fe3a5e2 100755 --- a/mig/server/grid_script.py +++ b/mig/server/grid_script.py @@ -28,31 +28,42 @@ """Main script running on the MiG server""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function -from builtins import range -from past.builtins import basestring -import sys -import time -import datetime import calendar -import threading +import copy +import datetime import os import signal -import copy +import sys +import threading +import time +from builtins import range + +from past.builtins import basestring from mig.server import jobscriptgenerator from mig.server.jobqueue import JobQueue from mig.shared.base import client_id_dir, generate_https_urls from mig.shared.conf import get_configuration_object, get_resource_exe from mig.shared.defaults import default_vgrid, maxfill_fields -from mig.shared.fileio import pickle, unpickle, unpickle_and_change_status, \ - send_message_to_grid_script -from mig.shared.gridscript import clean_grid_stdin, \ - remove_jobrequest_pending_files, check_mrsl_files, requeue_job, \ - server_cleanup, load_queue, save_queue, load_schedule_cache, \ - save_schedule_cache +from mig.shared.fileio import ( + pickle, + send_message_to_grid_script, + unpickle, + unpickle_and_change_status, +) +from mig.shared.gridscript import ( + check_mrsl_files, + clean_grid_stdin, + load_queue, + load_schedule_cache, + remove_jobrequest_pending_files, + requeue_job, + save_queue, + save_schedule_cache, + server_cleanup, +) from mig.shared.notification import notify_user_thread from mig.shared.resadm import atomic_resource_exe_restart, put_exe_pgid from mig.shared.vgrid import job_fits_res_vgrid, validated_vgrid_list @@ -61,12 +72,12 @@ try: from mig.server import servercomm except ImportError as ime: - print('could not import servercomm, probably due to missing pycurl') + print("could not import servercomm, probably due to missing pycurl") print(ime) -(configuration, logger) = (None, None) -(job_queue, executing_queue, scheduler) = (None, None, None) -(job_time_out_thread, job_time_out_stop) = (None, None) +configuration, logger = (None, None) +job_queue, executing_queue, scheduler = (None, None, None) +job_time_out_thread, job_time_out_stop = (None, None) def hangup_handler(signal, frame): @@ -109,9 +120,9 @@ def time_out_jobs(stop_event): qlen = executing_queue.queue_length() if qlen == 0: - logger.info('No jobs in executing_queue') + logger.info("No jobs in executing_queue") else: - logger.info('time_out_jobs(): %d job(s) in queue' % qlen) + logger.info("time_out_jobs(): %d job(s) in queue" % qlen) # TODO: this is a race - 'Main' may modify executing_queue at # any time! @@ -125,34 +136,41 @@ def time_out_jobs(stop_event): job = executing_queue.get_job(i) if not job: logger.warning( - 'time-out RC? 
found empty job in slot %d!' % i) + "time-out RC? found empty job in slot %d!" % i + ) continue try: - delay = int(job['EXECUTION_DELAY']) + delay = int(job["EXECUTION_DELAY"]) except Exception as err: logger.warning( - 'no execution delay field: %s Exception: %s' - % (job, err)) + "no execution delay field: %s Exception: %s" + % (job, err) + ) delay = 0 try: - cputime = int(job['CPUTIME']) + cputime = int(job["CPUTIME"]) except Exception as err: - logger.warning('cputime extraction failed for %s! %s' - % (job, err)) + logger.warning( + "cputime extraction failed for %s! %s" % (job, err) + ) cputime = 120 extra_cputime = 90 total_cputime = delay + extra_cputime + cputime - timestamp = job['EXECUTING_TIMESTAMP'] + timestamp = job["EXECUTING_TIMESTAMP"] # the canonical way to convert time.gmtime() to # a datetime... All times in UTC timezone - start_executing_datetime = \ - datetime.datetime.utcfromtimestamp(calendar.timegm( - timestamp)) + start_executing_datetime = ( + datetime.datetime.utcfromtimestamp( + calendar.timegm(timestamp) + ) + ) - last_valid_finish_time = start_executing_datetime\ + last_valid_finish_time = ( + start_executing_datetime + datetime.timedelta(seconds=total_cputime) + ) # now, in utc timezone @@ -160,24 +178,28 @@ def time_out_jobs(stop_event): if now > last_valid_finish_time: logger.info( - 'timing out job %s: allowed time %s, delay %s' - % (job['JOB_ID'], total_cputime, delay)) - grid_script_msg = 'JOBTIMEOUT %s %s %s\n'\ - % (job['UNIQUE_RESOURCE_NAME'], job['EXE' - ], job['JOB_ID']) - send_message_to_grid_script(grid_script_msg, - logger, configuration) + "timing out job %s: allowed time %s, delay %s" + % (job["JOB_ID"], total_cputime, delay) + ) + grid_script_msg = "JOBTIMEOUT %s %s %s\n" % ( + job["UNIQUE_RESOURCE_NAME"], + job["EXE"], + job["JOB_ID"], + ) + send_message_to_grid_script( + grid_script_msg, logger, configuration + ) except Exception as err: - logger.error('time_out_jobs: unexpected exception: %s' % err) - logger.info('time_out_jobs: time out thread terminating') + logger.error("time_out_jobs: unexpected exception: %s" % err) + logger.info("time_out_jobs: time out thread terminating") def clean_shutdown(signum, frame): """Request clean shutdown when pending requests are handled""" - print('--- REQUESTING SAFE SHUTDOWN ---') - shutdown_msg = 'SHUTDOWN\n' + print("--- REQUESTING SAFE SHUTDOWN ---") + shutdown_msg = "SHUTDOWN\n" send_message_to_grid_script(shutdown_msg, logger, configuration) @@ -187,27 +209,28 @@ def graceful_shutdown(): handler to avoid interfering with other active requests. 
""" - msg = '%s: graceful_shutdown called' % sys.argv[0] + msg = "%s: graceful_shutdown called" % sys.argv[0] print(msg) try: logger.info(msg) job_time_out_stop.set() - print('graceful_shutdown: giving time out thread a chance to terminate') + print("graceful_shutdown: giving time out thread a chance to terminate") # make sure queue gets saved even if timeout thread goes haywire job_time_out_thread.join(5) - print('graceful_shutdown: saving state') - if job_queue and not save_queue(job_queue, job_queue_path, - logger): - logger.warning('failed to save job queue') - if executing_queue and not save_queue(executing_queue, - executing_queue_path, logger): - logger.warning('failed to save executing queue') - if scheduler and not save_schedule_cache(scheduler.get_cache(), - schedule_cache_path, logger): - logger.warning('failed to save scheduler cache') - print('graceful_shutdown: saved state; now blocking for timeout thread') + print("graceful_shutdown: saving state") + if job_queue and not save_queue(job_queue, job_queue_path, logger): + logger.warning("failed to save job queue") + if executing_queue and not save_queue( + executing_queue, executing_queue_path, logger + ): + logger.warning("failed to save executing queue") + if scheduler and not save_schedule_cache( + scheduler.get_cache(), schedule_cache_path, logger + ): + logger.warning("failed to save scheduler cache") + print("graceful_shutdown: saved state; now blocking for timeout thread") # Now make sure timeout thread finishes @@ -242,26 +265,29 @@ def graceful_shutdown(): unless it is available in the default path mig/server/MiGserver.conf """) -logger.info('Starting MiG server') +logger.info("Starting MiG server") # Load queues from file dump if available -job_queue_path = os.path.join(configuration.mig_system_files, - 'job_queue.pickle') -executing_queue_path = os.path.join(configuration.mig_system_files, - 'executing_queue.pickle') -schedule_cache_path = os.path.join(configuration.mig_system_files, - 'schedule_cache.pickle') +job_queue_path = os.path.join( + configuration.mig_system_files, "job_queue.pickle" +) +executing_queue_path = os.path.join( + configuration.mig_system_files, "executing_queue.pickle" +) +schedule_cache_path = os.path.join( + configuration.mig_system_files, "schedule_cache.pickle" +) only_new_jobs = True job_queue = load_queue(job_queue_path, logger) executing_queue = load_queue(executing_queue_path, logger) if not job_queue or not executing_queue: - logger.warning('Could not load queues from previous run') + logger.warning("Could not load queues from previous run") only_new_jobs = False job_queue = JobQueue(logger) executing_queue = JobQueue(logger) else: - logger.info('Loaded queues from previous run') + logger.info("Loaded queues from previous run") # Always use an empty done queue after restart @@ -269,33 +295,42 @@ def graceful_shutdown(): schedule_cache = load_schedule_cache(schedule_cache_path, logger) if not schedule_cache: - logger.warning('Could not load schedule cache from previous run') + logger.warning("Could not load schedule cache from previous run") else: - logger.info('Loaded schedule cache from previous run') + logger.info("Loaded schedule cache from previous run") -logger.info('starting scheduler ' + configuration.sched_alg) -if configuration.sched_alg == 'FirstFit': +logger.info("starting scheduler " + configuration.sched_alg) +if configuration.sched_alg == "FirstFit": from mig.server.firstfitscheduler import FirstFitScheduler + scheduler = FirstFitScheduler(logger, configuration) -elif 
configuration.sched_alg == 'BestFit': +elif configuration.sched_alg == "BestFit": from mig.server.bestfitscheduler import BestFitScheduler + scheduler = BestFitScheduler(logger, configuration) -elif configuration.sched_alg == 'FairFit': +elif configuration.sched_alg == "FairFit": from mig.server.fairfitscheduler import FairFitScheduler + scheduler = FairFitScheduler(logger, configuration) -elif configuration.sched_alg == 'MaxThroughput': +elif configuration.sched_alg == "MaxThroughput": from mig.server.maxthroughputscheduler import MaxThroughputScheduler + scheduler = MaxThroughputScheduler(logger, configuration) -elif configuration.sched_alg == 'Random': +elif configuration.sched_alg == "Random": from mig.server.randomscheduler import RandomScheduler + scheduler = RandomScheduler(logger, configuration) -elif configuration.sched_alg == 'FIFO': +elif configuration.sched_alg == "FIFO": from mig.server.fifoscheduler import FIFOScheduler + scheduler = FIFOScheduler(logger, configuration) else: from mig.server.firstfitscheduler import FirstFitScheduler - print('Unknown sched_alg %s - using FirstFit scheduler' - % configuration.sched_alg) + + print( + "Unknown sched_alg %s - using FirstFit scheduler" + % configuration.sched_alg + ) scheduler = FirstFitScheduler(logger, configuration) scheduler.attach_job_queue(job_queue) @@ -307,53 +342,57 @@ def graceful_shutdown(): try: if not os.path.exists(configuration.grid_stdin): - logger.info('creating grid_script input pipe %s' - % configuration.grid_stdin) + logger.info( + "creating grid_script input pipe %s" % configuration.grid_stdin + ) try: os.mkfifo(configuration.grid_stdin) except Exception as err: - logger.error('Could not create missing grid_stdin fifo: ' - + '%s exception: %s ' - % (configuration.grid_stdin, err)) - grid_stdin = open(configuration.grid_stdin, 'r') + logger.error( + "Could not create missing grid_stdin fifo: " + + "%s exception: %s " % (configuration.grid_stdin, err) + ) + grid_stdin = open(configuration.grid_stdin, "r") except Exception: - logger.error('failed to open grid_stdin! %s' % sys.exc_info()[0]) + logger.error("failed to open grid_stdin! 
%s" % sys.exc_info()[0]) sys.exit(1) -logger.info('cleaning pipe') +logger.info("cleaning pipe") clean_grid_stdin(grid_stdin) # Make sure empty job home exists -empty_home = os.path.join(configuration.user_home, - configuration.empty_job_name) +empty_home = os.path.join(configuration.user_home, configuration.empty_job_name) if not os.path.exists(empty_home): - logger.info('creating empty job home dir %s' % empty_home) + logger.info("creating empty job home dir %s" % empty_home) try: os.mkdir(empty_home) except Exception as exc: - logger.error('failed to create empty job home dir %s: %s' - % (empty_home, exc)) + logger.error( + "failed to create empty job home dir %s: %s" % (empty_home, exc) + ) -msg = 'Checking for mRSL files with status parse or queued' +msg = "Checking for mRSL files with status parse or queued" print(msg) logger.info(msg) -check_mrsl_files(configuration, job_queue, executing_queue, - only_new_jobs, logger) +check_mrsl_files( + configuration, job_queue, executing_queue, only_new_jobs, logger +) -msg = 'Cleaning up after pending job requests' +msg = "Cleaning up after pending job requests" print(msg) remove_jobrequest_pending_files(configuration) # start the timer function to check if cputime is exceeded -logger.info('starting time_out_jobs()') +logger.info("starting time_out_jobs()") job_time_out_stop = threading.Event() -job_time_out_thread = threading.Thread(target=time_out_jobs, - args=(job_time_out_stop, )) +job_time_out_thread = threading.Thread( + target=time_out_jobs, args=(job_time_out_stop,) +) job_time_out_thread.start() -msg = 'Starting main loop' +msg = "Starting main loop" print(msg) logger.info(msg) @@ -374,8 +413,8 @@ def graceful_shutdown(): line = grid_stdin.readline() strip_line = line.strip() cap_line = strip_line.upper() - linelist = strip_line.split(' ') - if strip_line == '': + linelist = strip_line.split(" ") + if strip_line == "": if last_read_from_grid_stdin_empty: time.sleep(1) last_read_from_grid_stdin_empty = True @@ -386,7 +425,7 @@ def graceful_shutdown(): else: last_read_from_grid_stdin_empty = False - if cap_line.find('USERJOBFILE ') == 0: + if cap_line.find("USERJOBFILE ") == 0: # ********* ********* # ********* USER JOB ********* @@ -397,39 +436,44 @@ def graceful_shutdown(): # add to queue - file_userjob = configuration.mrsl_files_dir\ - + strip_line.replace('USERJOBFILE ', '') + '.mRSL' + file_userjob = ( + configuration.mrsl_files_dir + + strip_line.replace("USERJOBFILE ", "") + + ".mRSL" + ) dict_userjob = unpickle_and_change_status( - file_userjob, 'QUEUED', logger) + file_userjob, "QUEUED", logger + ) if not dict_userjob: - logger.error('Could not unpickle and change status. ' - + 'Job not enqueued!') + logger.error( + "Could not unpickle and change status. " + "Job not enqueued!" 
+ ) continue # Set owner to be able to do per-user job statistics - user_str = strip_line.replace('USERJOBFILE ', '') - (user_id, filename) = user_str.split(os.sep) + user_str = strip_line.replace("USERJOBFILE ", "") + user_id, filename = user_str.split(os.sep) - dict_userjob['OWNER'] = user_id - dict_userjob['MIGRATE_COUNT'] = "0" + dict_userjob["OWNER"] = user_id + dict_userjob["MIGRATE_COUNT"] = "0" # put job in queue job_queue.enqueue_job(dict_userjob, job_queue.queue_length()) user_dict = {} - user_dict['USER_ID'] = user_id + user_dict["USER_ID"] = user_id # Update list of users - create user if new scheduler.update_users(user_dict) user_dict = scheduler.find_user(user_dict) - user_dict['QUEUE_HIST'].pop(0) - user_dict['QUEUE_HIST'].append(dict_userjob) + user_dict["QUEUE_HIST"].pop(0) + user_dict["QUEUE_HIST"].append(dict_userjob) scheduler.update_seen(user_dict) - elif cap_line.find('SERVERJOBFILE ') == 0: + elif cap_line.find("SERVERJOBFILE ") == 0: # ********* ********* # ********* SERVER JOB ********* @@ -440,18 +484,22 @@ def graceful_shutdown(): # add to queue - file_serverjob = configuration.mrsl_files_dir\ - + strip_line.replace('SERVERJOBFILE ', '') + '.mRSL' + file_serverjob = ( + configuration.mrsl_files_dir + + strip_line.replace("SERVERJOBFILE ", "") + + ".mRSL" + ) dict_serverjob = unpickle(file_serverjob, logger) if dict_serverjob is False: logger.error( - 'Could not unpickle migrated job - not put into queue!') + "Could not unpickle migrated job - not put into queue!" + ) continue # put job in queue job_queue.enqueue_job(dict_serverjob, job_queue.queue_length()) - elif cap_line.find('JOBSCHEDULE ') == 0: + elif cap_line.find("JOBSCHEDULE ") == 0: # ********* ********* # ********* SCHEDULE DUMP ********* @@ -461,7 +509,7 @@ def graceful_shutdown(): logger.info(cap_line) if len(linelist) != 2: - logger.error('Invalid job schedule request %s' % linelist) + logger.error("Invalid job schedule request %s" % linelist) continue # read values @@ -472,22 +520,27 @@ def graceful_shutdown(): job_dict = job_queue.get_job_by_id(job_id) if not job_dict: - logger.info('Job is not in waiting queue - no schedule to update') + logger.info("Job is not in waiting queue - no schedule to update") continue - client_dir = client_id_dir(job_dict['USER_CERT']) - file_serverjob = configuration.mrsl_files_dir + client_dir\ - + os.sep + job_id + '.mRSL' + client_dir = client_id_dir(job_dict["USER_CERT"]) + file_serverjob = ( + configuration.mrsl_files_dir + + client_dir + + os.sep + + job_id + + ".mRSL" + ) dict_serverjob = unpickle(file_serverjob, logger) if dict_serverjob is False: - logger.error('Could not unpickle job - not updating schedule!') + logger.error("Could not unpickle job - not updating schedule!") continue # update and save schedule scheduler.copy_schedule(job_dict, dict_serverjob) pickle(dict_serverjob, file_serverjob, logger) - elif cap_line.find('RESOURCEREQUEST ') == 0: + elif cap_line.find("RESOURCEREQUEST ") == 0: # ********* ********* # ********* RESOURCE REQUEST ********* @@ -495,11 +548,13 @@ def graceful_shutdown(): print(cap_line) logger.info(cap_line) - logger.info('RESOURCEREQUEST: %d job(s) in the queue.' % - job_queue.queue_length()) + logger.info( + "RESOURCEREQUEST: %d job(s) in the queue." 
+ % job_queue.queue_length() + ) if len(linelist) != 8: - logger.error('Invalid resource request %s' % linelist) + logger.error("Invalid resource request %s" % linelist) continue # read values @@ -515,19 +570,21 @@ def graceful_shutdown(): # read resource config file - res_file = os.path.join(configuration.resource_home, - unique_resource_name, 'config') + res_file = os.path.join( + configuration.resource_home, unique_resource_name, "config" + ) resource_config = unpickle(res_file, logger) if resource_config is False: - logger.error('error unpickling resource config for %s' - % unique_resource_name) + logger.error( + "error unpickling resource config for %s" % unique_resource_name + ) continue - sandboxed = resource_config.get('SANDBOX', False) + sandboxed = resource_config.get("SANDBOX", False) # Write the PGID of EXE to PGID file - (status, msg) = put_exe_pgid( + status, msg = put_exe_pgid( configuration.resource_home, unique_resource_name, exe, @@ -539,8 +596,9 @@ def graceful_shutdown(): logger.info(msg) else: logger.error( - 'Problem writing EXE PGID to file, job request aborted: %s' - % msg) + "Problem writing EXE PGID to file, job request aborted: %s" + % msg + ) # we cannot create and dispatch job without pgid written to file! @@ -551,78 +609,85 @@ def graceful_shutdown(): # mark job failed if resource requests a new job and # previously dispatched job is not marked done yet - last_req_file = os.path.join(configuration.resource_home, - unique_resource_name, - 'last_request.%s' % exe) + last_req_file = os.path.join( + configuration.resource_home, + unique_resource_name, + "last_request.%s" % exe, + ) last_req = unpickle(last_req_file, logger) if last_req is False: # last_req could not be pickled, this is probably # because it is the first request from the resource - last_req = {'EMPTY_JOB': True} + last_req = {"EMPTY_JOB": True} - if last_req.get('EMPTY_JOB', False) or not last_req.get('USER_CERT', - None): + if last_req.get("EMPTY_JOB", False) or not last_req.get( + "USER_CERT", None + ): # Dequeue empty job and cleanup (if not already done in FINISH) # This is done to avoid them stacking up in the executing_queue # in case of a faulty resource who keeps requesting jobs - job_dict = \ - executing_queue.dequeue_job_by_id(last_req.get( - 'JOB_ID', ''), log_errors=False) + job_dict = executing_queue.dequeue_job_by_id( + last_req.get("JOB_ID", ""), log_errors=False + ) if job_dict: - logger.info('last job was an empty job which did not finish') + logger.info("last job was an empty job which did not finish") if not server_cleanup( - job_dict['SESSIONID'], - job_dict['IOSESSIONID'], - job_dict['LOCALJOBNAME'], - job_dict['JOB_ID'], + job_dict["SESSIONID"], + job_dict["IOSESSIONID"], + job_dict["LOCALJOBNAME"], + job_dict["JOB_ID"], configuration, logger, ): - logger.error('could not clean up MiG server') + logger.error("could not clean up MiG server") else: - logger.info('last job was an empty job which already finished') + logger.info("last job was an empty job which already finished") else: # open the mRSL file belonging to the last request # and check if the status is FINISHED or CANCELED. 
- last_job_ok_status_list = ['FINISHED', 'CANCELED'] - client_dir = client_id_dir(last_req['USER_CERT']) - filenamelast = os.path.join(configuration.mrsl_files_dir, - client_dir, - last_req['JOB_ID'] + '.mRSL') + last_job_ok_status_list = ["FINISHED", "CANCELED"] + client_dir = client_id_dir(last_req["USER_CERT"]) + filenamelast = os.path.join( + configuration.mrsl_files_dir, + client_dir, + last_req["JOB_ID"] + ".mRSL", + ) job_dict = unpickle(filenamelast, logger) if job_dict: - if job_dict['STATUS'] not in last_job_ok_status_list: + if job_dict["STATUS"] not in last_job_ok_status_list: last_job_failed = True - exe_job = \ - executing_queue.get_job_by_id(job_dict['JOB_ID' - ]) + exe_job = executing_queue.get_job_by_id(job_dict["JOB_ID"]) # Ignore missing fields - (last_res, last_exe) = ('', '') + last_res, last_exe = ("", "") if exe_job: - if 'UNIQUE_RESOURCE_NAME' in exe_job: - last_res = exe_job['UNIQUE_RESOURCE_NAME'] - if 'EXE' in exe_job: - last_exe = exe_job['EXE'] - - if exe_job and last_res == unique_resource_name\ - and last_exe == exe: + if "UNIQUE_RESOURCE_NAME" in exe_job: + last_res = exe_job["UNIQUE_RESOURCE_NAME"] + if "EXE" in exe_job: + last_exe = exe_job["EXE"] + + if ( + exe_job + and last_res == unique_resource_name + and last_exe == exe + ): logger.info( - '%s:%s requested job and was NOT done with last %s' - % (unique_resource_name, exe, job_dict['JOB_ID'])) - print('YOU ARE NOT DONE WITH %s' % job_dict['JOB_ID']) + "%s:%s requested job and was NOT done with last %s" + % (unique_resource_name, exe, job_dict["JOB_ID"]) + ) + print("YOU ARE NOT DONE WITH %s" % job_dict["JOB_ID"]) # Clear any scheduling data for exe_job before requeue scheduler.clear_schedule(exe_job) requeue_job( exe_job, - 'RESOURCE DIED', + "RESOURCE DIED", job_queue, executing_queue, configuration, @@ -630,60 +695,65 @@ def graceful_shutdown(): ) else: logger.info( - '%s:%s requested job but last %s was rescheduled' - % (unique_resource_name, exe, job_dict['JOB_ID'])) - print('YOUR LAST JOB %s WAS RESCHEDULED' - % job_dict['JOB_ID']) + "%s:%s requested job but last %s was rescheduled" + % (unique_resource_name, exe, job_dict["JOB_ID"]) + ) + print( + "YOUR LAST JOB %s WAS RESCHEDULED" + % job_dict["JOB_ID"] + ) else: - logger.info('%s requested job and previous was done' - % unique_resource_name) - print('OK, last job %s was done' % job_dict['JOB_ID']) + logger.info( + "%s requested job and previous was done" + % unique_resource_name + ) + print("OK, last job %s was done" % job_dict["JOB_ID"]) # Now update resource config fields with requested attributes - resource_config['CPUTIME'] = cputime + resource_config["CPUTIME"] = cputime # overwrite execution_delay attribute - resource_config['EXECUTION_DELAY'] = execution_delay + resource_config["EXECUTION_DELAY"] = execution_delay # overwrite number of available nodes (a pbs resource might not # want a job for all nodes) - resource_config['NODECOUNT'] = nodecount - resource_config['RESOURCE_ID'] = '%s_%s'\ - % (unique_resource_name, exe) + resource_config["NODECOUNT"] = nodecount + resource_config["RESOURCE_ID"] = "%s_%s" % (unique_resource_name, exe) # specify vgrid - (status, exe_conf) = get_resource_exe(resource_config, exe, - logger) + status, exe_conf = get_resource_exe(resource_config, exe, logger) if not status: - logger.error('could not get exe configuration for resource!') + logger.error("could not get exe configuration for resource!") continue - last_request_dict = {'RESOURCE_CONFIG': resource_config, - 'CREATED_TIME': datetime.datetime.now(), 
- 'STATUS': ''} + last_request_dict = { + "RESOURCE_CONFIG": resource_config, + "CREATED_TIME": datetime.datetime.now(), + "STATUS": "", + } # find the vgrid that should receive the job request last_vgrid = 0 - if not exe_conf.get('vgrid', ''): + if not exe_conf.get("vgrid", ""): # fall back to default vgrid - exe_conf['vgrid'] = [default_vgrid] + exe_conf["vgrid"] = [default_vgrid] - if isinstance(exe_conf['vgrid'], basestring): - exe_conf['vgrid'] = list(exe_conf['vgrid']) - exe_vgrids = exe_conf['vgrid'] + if isinstance(exe_conf["vgrid"], basestring): + exe_conf["vgrid"] = list(exe_conf["vgrid"]) + exe_vgrids = exe_conf["vgrid"] - if 'LAST_VGRID' in last_req: + if "LAST_VGRID" in last_req: # index of last vgrid found - last_vgrid_index = last_req['LAST_VGRID'] + last_vgrid_index = last_req["LAST_VGRID"] # make sure the index is within bounds (some vgrids # might have been removed from conf since last run) @@ -706,24 +776,27 @@ def graceful_shutdown(): vgrids_in_prioritized_order = [] - list_indices = [(last_vgrid + i) % len(exe_vgrids) - for i in range(len(exe_vgrids))] + list_indices = [ + (last_vgrid + i) % len(exe_vgrids) for i in range(len(exe_vgrids)) + ] for index in list_indices: # replace "" with default_vgrid - add_vgrid = exe_conf['vgrid'][index] - if add_vgrid == '': + add_vgrid = exe_conf["vgrid"][index] + if add_vgrid == "": add_vgrid = default_vgrid vgrids_in_prioritized_order.append(add_vgrid) - logger.info('vgrids in prioritized order: %s (last %s)' - % (vgrids_in_prioritized_order, last_vgrid)) + logger.info( + "vgrids in prioritized order: %s (last %s)" + % (vgrids_in_prioritized_order, last_vgrid) + ) # set found values - resource_config['VGRID'] = vgrids_in_prioritized_order - resource_config['LAST_VGRID'] = last_vgrid - last_request_dict['LAST_VGRID'] = last_vgrid + resource_config["VGRID"] = vgrids_in_prioritized_order + resource_config["LAST_VGRID"] = last_vgrid + last_request_dict["LAST_VGRID"] = last_vgrid # Update list of resources @@ -735,12 +808,12 @@ def graceful_shutdown(): # No jobs: Create 'empty' job script and double sleep time if # repeated empty job - if 'EMPTY_JOB' not in last_req: + if "EMPTY_JOB" not in last_req: sleep_factor = 1.0 else: sleep_factor = 2.0 - print('N') - (empty_job, msg) = jobscriptgenerator.create_empty_job( + print("N") + empty_job, msg = jobscriptgenerator.create_empty_job( unique_resource_name, exe, cputime, @@ -750,8 +823,7 @@ def graceful_shutdown(): configuration, logger, ) - (new_job, msg) = \ - jobscriptgenerator.create_job_script( + new_job, msg = jobscriptgenerator.create_job_script( unique_resource_name, exe, empty_job, @@ -761,31 +833,34 @@ def graceful_shutdown(): logger, ) if new_job: - last_request_dict['JOB_ID'] = empty_job['JOB_ID'] - last_request_dict['STATUS'] = 'No jobs in queue' + last_request_dict["JOB_ID"] = empty_job["JOB_ID"] + last_request_dict["STATUS"] = "No jobs in queue" if last_job_failed: - last_request_dict['STATUS'] = \ - 'Last job failed - forced empty job' - last_request_dict['EXECUTING_TIMESTAMP'] = time.gmtime() - last_request_dict['EXECUTION_DELAY'] = \ - empty_job['EXECUTION_DELAY'] - last_request_dict['UNIQUE_RESOURCE_NAME'] = \ - unique_resource_name - last_request_dict['PUBLICNAME'] = resource_config.get( - 'PUBLICNAME', 'HIDDEN') - last_request_dict['EXE'] = exe - last_request_dict['RESOURCE_CONFIG'] = resource_config - last_request_dict['LOCALJOBNAME'] = localjobname - last_request_dict['SESSIONID'] = new_job['SESSIONID'] - last_request_dict['IOSESSIONID'] = new_job['IOSESSIONID'] - 
last_request_dict['CPUTIME'] = empty_job['CPUTIME'] - last_request_dict['EMPTY_JOB'] = True - - executing_queue.enqueue_job(last_request_dict, - executing_queue.queue_length()) - logger.info('empty job script created') + last_request_dict["STATUS"] = ( + "Last job failed - forced empty job" + ) + last_request_dict["EXECUTING_TIMESTAMP"] = time.gmtime() + last_request_dict["EXECUTION_DELAY"] = empty_job[ + "EXECUTION_DELAY" + ] + last_request_dict["UNIQUE_RESOURCE_NAME"] = unique_resource_name + last_request_dict["PUBLICNAME"] = resource_config.get( + "PUBLICNAME", "HIDDEN" + ) + last_request_dict["EXE"] = exe + last_request_dict["RESOURCE_CONFIG"] = resource_config + last_request_dict["LOCALJOBNAME"] = localjobname + last_request_dict["SESSIONID"] = new_job["SESSIONID"] + last_request_dict["IOSESSIONID"] = new_job["IOSESSIONID"] + last_request_dict["CPUTIME"] = empty_job["CPUTIME"] + last_request_dict["EMPTY_JOB"] = True + + executing_queue.enqueue_job( + last_request_dict, executing_queue.queue_length() + ) + logger.info("empty job script created") else: - msg = 'Failed to create job script: %s' % msg + msg = "Failed to create job script: %s" % msg print(msg) logger.error(msg) continue @@ -808,24 +883,30 @@ def graceful_shutdown(): notify_user_thread( expired, - generate_https_urls(configuration, - '%(auto_base)s/%(auto_bin)s/ls.py', - {}), - 'EXPIRED', + generate_https_urls( + configuration, "%(auto_base)s/%(auto_bin)s/ls.py", {} + ), + "EXPIRED", logger, False, configuration, ) - client_dir = client_id_dir(expired['USER_CERT']) - expired_file = configuration.mrsl_files_dir + client_dir\ - + os.sep + expired['JOB_ID'] + '.mRSL' - - if not unpickle_and_change_status(expired_file, - 'EXPIRED', logger): - logger.error('Could not unpickle and change status. ' + client_dir = client_id_dir(expired["USER_CERT"]) + expired_file = ( + configuration.mrsl_files_dir + + client_dir + + os.sep + + expired["JOB_ID"] + + ".mRSL" + ) - + 'Job could not be officially expired!' - ) + if not unpickle_and_change_status( + expired_file, "EXPIRED", logger + ): + logger.error( + "Could not unpickle and change status. " + + "Job could not be officially expired!" + ) continue # Remove references to expired jobs @@ -842,38 +923,43 @@ def graceful_shutdown(): if not job_dict: break - client_dir = client_id_dir(job_dict['USER_CERT']) - mrsl_filename = configuration.mrsl_files_dir\ - + client_dir + '/' + job_dict['JOB_ID'] + '.mRSL' + client_dir = client_id_dir(job_dict["USER_CERT"]) + mrsl_filename = ( + configuration.mrsl_files_dir + + client_dir + + "/" + + job_dict["JOB_ID"] + + ".mRSL" + ) dummy_dict = unpickle(mrsl_filename, logger) # The job status should be "QUEUED" at this point if dummy_dict is False: - logger.error('error unpickling mrsl in %s' - % mrsl_filename) + logger.error("error unpickling mrsl in %s" % mrsl_filename) continue - if dummy_dict['STATUS'] == 'QUEUED': + if dummy_dict["STATUS"] == "QUEUED": break if not job_dict: # no jobs in the queue fits the resource! 
- print('X') - logger.info('No jobs in the queue can be executed by ' - + 'resource, queue length: %s' - % job_queue.queue_length()) + print("X") + logger.info( + "No jobs in the queue can be executed by " + + "resource, queue length: %s" % job_queue.queue_length() + ) # Create 'empty' job script and double sleep time if # repeated empty job - if 'EMPTY_JOB' not in last_req: + if "EMPTY_JOB" not in last_req: sleep_factor = 1.0 else: sleep_factor = 2.0 - (empty_job, msg) = jobscriptgenerator.create_empty_job( + empty_job, msg = jobscriptgenerator.create_empty_job( unique_resource_name, exe, cputime, @@ -883,8 +969,7 @@ def graceful_shutdown(): configuration, logger, ) - (new_job, msg) = \ - jobscriptgenerator.create_job_script( + new_job, msg = jobscriptgenerator.create_job_script( unique_resource_name, exe, empty_job, @@ -894,42 +979,44 @@ def graceful_shutdown(): logger, ) if new_job: - last_request_dict['JOB_ID'] = empty_job['JOB_ID'] - last_request_dict['STATUS'] = \ - 'No jobs in queue can be executed by resource' - last_request_dict['EXECUTING_TIMESTAMP'] = \ - time.gmtime() - last_request_dict['EXECUTION_DELAY'] = \ - execution_delay - last_request_dict['UNIQUE_RESOURCE_NAME'] = \ + last_request_dict["JOB_ID"] = empty_job["JOB_ID"] + last_request_dict["STATUS"] = ( + "No jobs in queue can be executed by resource" + ) + last_request_dict["EXECUTING_TIMESTAMP"] = time.gmtime() + last_request_dict["EXECUTION_DELAY"] = execution_delay + last_request_dict["UNIQUE_RESOURCE_NAME"] = ( unique_resource_name - last_request_dict['PUBLICNAME'] = resource_config.get( - 'PUBLICNAME', 'HIDDEN') - last_request_dict['EXE'] = exe - last_request_dict['RESOURCE_CONFIG'] = \ - resource_config - last_request_dict['LOCALJOBNAME'] = localjobname - last_request_dict['SESSIONID'] = new_job['SESSIONID'] - last_request_dict['IOSESSIONID'] = new_job['IOSESSIONID'] - last_request_dict['CPUTIME'] = empty_job['CPUTIME'] - last_request_dict['EMPTY_JOB'] = True - - executing_queue.enqueue_job(last_request_dict, - executing_queue.queue_length()) - logger.info('empty job script created') + ) + last_request_dict["PUBLICNAME"] = resource_config.get( + "PUBLICNAME", "HIDDEN" + ) + last_request_dict["EXE"] = exe + last_request_dict["RESOURCE_CONFIG"] = resource_config + last_request_dict["LOCALJOBNAME"] = localjobname + last_request_dict["SESSIONID"] = new_job["SESSIONID"] + last_request_dict["IOSESSIONID"] = new_job["IOSESSIONID"] + last_request_dict["CPUTIME"] = empty_job["CPUTIME"] + last_request_dict["EMPTY_JOB"] = True + + executing_queue.enqueue_job( + last_request_dict, executing_queue.queue_length() + ) + logger.info("empty job script created") else: # a job has been scheduled to be executed on this # resource: change status in the mRSL file - client_dir = client_id_dir(job_dict['USER_CERT']) - mrsl_filename = os.path.join(configuration.mrsl_files_dir, - client_dir, - job_dict['JOB_ID'] + '.mRSL') + client_dir = client_id_dir(job_dict["USER_CERT"]) + mrsl_filename = os.path.join( + configuration.mrsl_files_dir, + client_dir, + job_dict["JOB_ID"] + ".mRSL", + ) mrsl_dict = unpickle(mrsl_filename, logger) if mrsl_dict: - (new_job, msg) = \ - jobscriptgenerator.create_job_script( + new_job, msg = jobscriptgenerator.create_job_script( unique_resource_name, exe, job_dict, @@ -944,87 +1031,98 @@ def graceful_shutdown(): # Fix legacy VGRID fields - mrsl_dict['VGRID'] = validated_vgrid_list( - configuration, mrsl_dict) + mrsl_dict["VGRID"] = validated_vgrid_list( + configuration, mrsl_dict + ) # Select actual VGrid to use 
- (match, active_job_vgrid, active_res_vgrid) = \ - job_fits_res_vgrid(mrsl_dict['VGRID'], - vgrids_in_prioritized_order) + match, active_job_vgrid, active_res_vgrid = ( + job_fits_res_vgrid( + mrsl_dict["VGRID"], vgrids_in_prioritized_order + ) + ) # Write executing details to mRSL file - mrsl_dict['STATUS'] = 'EXECUTING' - mrsl_dict['EXECUTING_TIMESTAMP'] = time.gmtime() - mrsl_dict['EXECUTION_DELAY'] = execution_delay - mrsl_dict['UNIQUE_RESOURCE_NAME'] = \ - unique_resource_name - mrsl_dict['PUBLICNAME'] = resource_config.get( - 'PUBLICNAME', 'HIDDEN') - mrsl_dict['EXE'] = exe - mrsl_dict['RESOURCE_VGRID'] = active_res_vgrid - mrsl_dict['RESOURCE_CONFIG'] = resource_config - mrsl_dict['LOCALJOBNAME'] = localjobname - mrsl_dict['SESSIONID'] = new_job['SESSIONID'] - mrsl_dict['IOSESSIONID'] = new_job['IOSESSIONID'] - mrsl_dict['MOUNTSSHPUBLICKEY'] = new_job['MOUNTSSHPUBLICKEY'] - mrsl_dict['MOUNTSSHPRIVATEKEY'] = new_job['MOUNTSSHPRIVATEKEY'] + mrsl_dict["STATUS"] = "EXECUTING" + mrsl_dict["EXECUTING_TIMESTAMP"] = time.gmtime() + mrsl_dict["EXECUTION_DELAY"] = execution_delay + mrsl_dict["UNIQUE_RESOURCE_NAME"] = unique_resource_name + mrsl_dict["PUBLICNAME"] = resource_config.get( + "PUBLICNAME", "HIDDEN" + ) + mrsl_dict["EXE"] = exe + mrsl_dict["RESOURCE_VGRID"] = active_res_vgrid + mrsl_dict["RESOURCE_CONFIG"] = resource_config + mrsl_dict["LOCALJOBNAME"] = localjobname + mrsl_dict["SESSIONID"] = new_job["SESSIONID"] + mrsl_dict["IOSESSIONID"] = new_job["IOSESSIONID"] + mrsl_dict["MOUNTSSHPUBLICKEY"] = new_job[ + "MOUNTSSHPUBLICKEY" + ] + mrsl_dict["MOUNTSSHPRIVATEKEY"] = new_job[ + "MOUNTSSHPRIVATEKEY" + ] # pickle the new version pickle(mrsl_dict, mrsl_filename, logger) - last_request_dict['STATUS'] = 'Job assigned' - last_request_dict['CPUTIME'] = \ - new_job['CPUTIME'] - last_request_dict['EXECUTION_DELAY'] = \ - execution_delay - last_request_dict['NODECOUNT'] = \ - new_job['NODECOUNT'] + last_request_dict["STATUS"] = "Job assigned" + last_request_dict["CPUTIME"] = new_job["CPUTIME"] + last_request_dict["EXECUTION_DELAY"] = execution_delay + last_request_dict["NODECOUNT"] = new_job["NODECOUNT"] # job id and user_cert is used to check if the current # job is done when a resource requests a new job - last_request_dict['JOB_ID'] = new_job['JOB_ID'] - last_request_dict['USER_CERT'] = new_job['USER_CERT'] + last_request_dict["JOB_ID"] = new_job["JOB_ID"] + last_request_dict["USER_CERT"] = new_job["USER_CERT"] # Save actual VGrid for fair VGrid cycling try: vgrid_index = vgrids_in_prioritized_order.index( - active_res_vgrid) + active_res_vgrid + ) except Exception: # fall back to simple increment vgrid_index = last_vgrid - last_request_dict['LAST_VGRID'] = vgrid_index + last_request_dict["LAST_VGRID"] = vgrid_index - print('Job assigned ' + new_job['JOB_ID']) - logger.info('Job %s assigned to %s execution unit %s' - % (new_job['JOB_ID'], - unique_resource_name, exe)) + print("Job assigned " + new_job["JOB_ID"]) + logger.info( + "Job %s assigned to %s execution unit %s" + % (new_job["JOB_ID"], unique_resource_name, exe) + ) - if 'WORKFLOW_TRIGGER_ID' in new_job: + if "WORKFLOW_TRIGGER_ID" in new_job: created, msg = create_workflow_job_history_file( configuration, - new_job['VGRID'][0], - new_job['SESSIONID'], - new_job['JOB_ID'], - mrsl_dict['WORKFLOW_TRIGGER_ID'], - mrsl_dict['WORKFLOW_TRIGGER_PATH'], - mrsl_dict['WORKFLOW_TRIGGER_TIME'], - mrsl_dict['WORKFLOW_PATTERN_NAME'], - mrsl_dict['WORKFLOW_PATTERN_ID'], - mrsl_dict['WORKFLOW_RECIPES'], + new_job["VGRID"][0], + 
new_job["SESSIONID"], + new_job["JOB_ID"], + mrsl_dict["WORKFLOW_TRIGGER_ID"], + mrsl_dict["WORKFLOW_TRIGGER_PATH"], + mrsl_dict["WORKFLOW_TRIGGER_TIME"], + mrsl_dict["WORKFLOW_PATTERN_NAME"], + mrsl_dict["WORKFLOW_PATTERN_ID"], + mrsl_dict["WORKFLOW_RECIPES"], ) if not created: - logger.error("Could not create job history " - "file %s for job %s. %s" - % (new_job['SESSIONID'], - new_job['JOB_ID'], msg)) + logger.error( + "Could not create job history " + "file %s for job %s. %s" + % ( + new_job["SESSIONID"], + new_job["JOB_ID"], + msg, + ) + ) # else: # logger.debug("Created new history file at: " # "%s" % msg) @@ -1038,23 +1136,26 @@ def graceful_shutdown(): for name in maxfill_fields: active_job[name] = new_job[name] - executing_queue.enqueue_job(active_job, - executing_queue.queue_length()) + executing_queue.enqueue_job( + active_job, executing_queue.queue_length() + ) - print('executing_queue length %d' - % executing_queue.queue_length()) + print( + "executing_queue length %d" + % executing_queue.queue_length() + ) else: # put original job in back in job queue - job_queue.enqueue_job(job_dict, - job_queue.queue_length()) - msg = 'error creating new job script, job requeued' + job_queue.enqueue_job( + job_dict, job_queue.queue_length() + ) + msg = "error creating new job script, job requeued" print(msg) logger.error(msg) else: - logger.error('error unpickling mRSL: %s' - % mrsl_filename) + logger.error("error unpickling mRSL: %s" % mrsl_filename) pickle(last_request_dict, last_req_file, logger) @@ -1087,8 +1188,8 @@ def graceful_shutdown(): # real job scheduled! - if 'VGRID' in job_dict: - original_last_request_dict_vgrids += job_dict['VGRID'] + if "VGRID" in job_dict: + original_last_request_dict_vgrids += job_dict["VGRID"] else: # no vgrid specified, this means default vgrid. 
@@ -1098,7 +1199,7 @@ def graceful_shutdown(): # overwrite last_request_dict for vgrids that # the resource is in but not executing the job - logger.info('job: %s' % job_dict) + logger.info("job: %s" % job_dict) for res_vgrid in vgrids_in_prioritized_order: if res_vgrid not in original_last_request_dict_vgrids: executing_in_other_vgrids.append(res_vgrid) @@ -1107,40 +1208,52 @@ def graceful_shutdown(): # empty job, make sure this job request is seen on monitors # for all vgrids this resource is in - original_last_request_dict_vgrids = \ - vgrids_in_prioritized_order + original_last_request_dict_vgrids = vgrids_in_prioritized_order # save monitor_last_request files # for vgrid_monitor in original_last_request_dict_vgrids: # loop all vgrids where this resource is taking jobs for vgrid_name in vgrids_in_prioritized_order: - logger.info("vgrid_name: '%s' org '%s' exe '%s'" - % (vgrid_name, - original_last_request_dict_vgrids, - executing_in_other_vgrids)) - - monitor_last_request_file = configuration.vgrid_home\ - + os.sep + vgrid_name + os.sep\ - + 'monitor_last_request_' + unique_resource_name + '_'\ + logger.info( + "vgrid_name: '%s' org '%s' exe '%s'" + % ( + vgrid_name, + original_last_request_dict_vgrids, + executing_in_other_vgrids, + ) + ) + + monitor_last_request_file = ( + configuration.vgrid_home + + os.sep + + vgrid_name + + os.sep + + "monitor_last_request_" + + unique_resource_name + + "_" + exe + ) if vgrid_name in original_last_request_dict_vgrids: - pickle(last_request_dict, monitor_last_request_file, - logger) - logger.info('vgrid_name: %s status: %s' % (vgrid_name, - last_request_dict['STATUS'])) + pickle(last_request_dict, monitor_last_request_file, logger) + logger.info( + "vgrid_name: %s status: %s" + % (vgrid_name, last_request_dict["STATUS"]) + ) elif vgrid_name in executing_in_other_vgrids: # create modified last_request_dict and save new_last_request_dict = copy.deepcopy(last_request_dict) - new_last_request_dict['STATUS'] = \ - 'Executing job for another vgrid' - logger.info('vgrid_name: %s status: %s' % (vgrid_name, - new_last_request_dict['STATUS'])) - pickle(new_last_request_dict, - monitor_last_request_file, logger) + new_last_request_dict["STATUS"] = ( + "Executing job for another vgrid" + ) + logger.info( + "vgrid_name: %s status: %s" + % (vgrid_name, new_last_request_dict["STATUS"]) + ) + pickle(new_last_request_dict, monitor_last_request_file, logger) else: # we should never enter this else, vgrid_name must be in @@ -1148,28 +1261,34 @@ def graceful_shutdown(): # executing_in_other_vgrids logger.error( - 'Entered else condition that never should be entered ' + - 'during creation of last_request_dict in grid_script!' + - " vgrid_name: '%s' not in '%s' or '%s'" - % (vgrid_name, original_last_request_dict_vgrids, - executing_in_other_vgrids)) + "Entered else condition that never should be entered " + + "during creation of last_request_dict in grid_script!" 
+ + " vgrid_name: '%s' not in '%s' or '%s'" + % ( + vgrid_name, + original_last_request_dict_vgrids, + executing_in_other_vgrids, + ) + ) # delete requestnewjob lock - lock_file = os.path.join(configuration.resource_home, - unique_resource_name, - 'jobrequest_pending.%s' % exe) + lock_file = os.path.join( + configuration.resource_home, + unique_resource_name, + "jobrequest_pending.%s" % exe, + ) try: os.remove(lock_file) except OSError as ose: - logger.error('Error removing %s: %s' % (lock_file, ose)) + logger.error("Error removing %s: %s" % (lock_file, ose)) # Experimental pricing code # TODO: update price *after* publishing status so that price fits delay? if configuration.enable_server_dist: scheduler.update_price(resource_config) - elif cap_line.find('RESOURCEFINISHEDJOB ') == 0: + elif cap_line.find("RESOURCEFINISHEDJOB ") == 0: # ********* ********* # ********* RESOURCE FINISHED ********* @@ -1178,11 +1297,13 @@ def graceful_shutdown(): print(cap_line) logger.info(cap_line) - logger.info('RESOURCEFINISHEDJOB: %d job(s) in the queue.' % - job_queue.queue_length()) + logger.info( + "RESOURCEFINISHEDJOB: %d job(s) in the queue." + % job_queue.queue_length() + ) if len(linelist) != 5: - logger.error('Invalid resourcefinishedjob request') + logger.error("Invalid resourcefinishedjob request") continue # read values @@ -1192,35 +1313,40 @@ def graceful_shutdown(): sessionid = linelist[3] job_id = linelist[4] - msg = 'RESOURCEFINISHEDJOB: %s:%s finished job %s id %s'\ - % (res_name, exe_name, sessionid, job_id) + msg = "RESOURCEFINISHEDJOB: %s:%s finished job %s id %s" % ( + res_name, + exe_name, + sessionid, + job_id, + ) job_dict = executing_queue.get_job_by_id(job_id) if not job_dict: - msg += \ - ', but job is not in executing queue, ignoring result.' - elif job_dict['UNIQUE_RESOURCE_NAME'] != res_name\ - or job_dict['EXE'] != exe_name: - msg += \ - ', but job is being executed by %s:%s, ignoring result.'\ - % (job_dict['UNIQUE_RESOURCE_NAME'], job_dict['EXE']) + msg += ", but job is not in executing queue, ignoring result." + elif ( + job_dict["UNIQUE_RESOURCE_NAME"] != res_name + or job_dict["EXE"] != exe_name + ): + msg += ", but job is being executed by %s:%s, ignoring result." % ( + job_dict["UNIQUE_RESOURCE_NAME"], + job_dict["EXE"], + ) else: # Clean up the server for files associated with the finished job if not server_cleanup( - job_dict['SESSIONID'], - job_dict['IOSESSIONID'], - job_dict['LOCALJOBNAME'], + job_dict["SESSIONID"], + job_dict["IOSESSIONID"], + job_dict["LOCALJOBNAME"], job_id, configuration, logger, ): - logger.error('could not clean up MiG server') + logger.error("could not clean up MiG server") - if configuration.enable_server_dist\ - and 'EMPTY_JOB' not in job_dict: + if configuration.enable_server_dist and "EMPTY_JOB" not in job_dict: # TODO: we should probably support resources migrating and # handing back job as first contact with new server @@ -1229,12 +1355,12 @@ def graceful_shutdown(): scheduler.finished_job(res_name, job_dict) executing_queue.dequeue_job_by_id(job_id) - msg += '%s removed from executing queue.' % job_id + msg += "%s removed from executing queue." % job_id # print msg logger.info(msg) - elif cap_line.find('RESTARTEXEFAILED') == 0: + elif cap_line.find("RESTARTEXEFAILED") == 0: # ********* ********* # ********* RESTART EXE FAILED ********* @@ -1243,11 +1369,12 @@ def graceful_shutdown(): print(cap_line) logger.info(cap_line) logger.info( - 'Before restart exe failed: %d job(s) in the executing queue.' 
% - executing_queue.queue_length()) + "Before restart exe failed: %d job(s) in the executing queue." + % executing_queue.queue_length() + ) if len(linelist) != 4: - logger.error('Invalid restart exe failed request') + logger.error("Invalid restart exe failed request") continue # read values @@ -1256,24 +1383,26 @@ def graceful_shutdown(): exe_name = linelist[2] job_id = linelist[3] - logger.info('Restart exe failed: adding retry job for %s %s' - % (res_name, exe_name)) - (retry_job, msg) = jobscriptgenerator.create_restart_job( + logger.info( + "Restart exe failed: adding retry job for %s %s" + % (res_name, exe_name) + ) + retry_job, msg = jobscriptgenerator.create_restart_job( res_name, exe_name, 300, 1, - 'RESTART-EXE-FAILED', + "RESTART-EXE-FAILED", 0, configuration, logger, ) - executing_queue.enqueue_job(retry_job, - executing_queue.queue_length()) + executing_queue.enqueue_job(retry_job, executing_queue.queue_length()) logger.info( - 'After restart exe failed: %d job(s) in the executing queue.' % - executing_queue.queue_length()) - elif cap_line.find('JOBACTION') == 0: + "After restart exe failed: %d job(s) in the executing queue." + % executing_queue.queue_length() + ) + elif cap_line.find("JOBACTION") == 0: # ********* ********* # ********* JOB STATE CHANGE ********* @@ -1281,11 +1410,12 @@ def graceful_shutdown(): print(cap_line) logger.info(cap_line) - logger.info('Job action: %d job(s) in the queue.' % - job_queue.queue_length()) + logger.info( + "Job action: %d job(s) in the queue." % job_queue.queue_length() + ) if len(linelist) != 6: - logger.error('Invalid job action request') + logger.error("Invalid job action request") continue # read values @@ -1298,25 +1428,30 @@ def graceful_shutdown(): # read resource config file - res_file = os.path.join(configuration.resource_home, - unique_resource_name, 'config') + res_file = os.path.join( + configuration.resource_home, unique_resource_name, "config" + ) resource_config = unpickle(res_file, logger) - other_status_list = ['PARSE'] - queued_status_list = ['QUEUED', 'RETRY', 'FROZEN'] - executing_status_list = ['EXECUTING'] + other_status_list = ["PARSE"] + queued_status_list = ["QUEUED", "RETRY", "FROZEN"] + executing_status_list = ["EXECUTING"] # Only cancel is accepted for non-queued states - if original_status not in queued_status_list and \ - new_status != 'CANCELED': - logger.error('change to %s not supported for jobs in %s states' - % (new_status, ', '.join(other_status_list))) + if ( + original_status not in queued_status_list + and new_status != "CANCELED" + ): + logger.error( + "change to %s not supported for jobs in %s states" + % (new_status, ", ".join(other_status_list)) + ) if original_status in other_status_list: pass elif original_status in queued_status_list: - if new_status == 'CANCELED': + if new_status == "CANCELED": job_dict = job_queue.dequeue_job_by_id(job_id) else: job_dict = job_queue.get_job_by_id(job_id) @@ -1324,7 +1459,7 @@ def graceful_shutdown(): logger.warning("Couldn't find job in queue: %s" % job_id) continue scheduler.clear_schedule(job_dict) - job_dict['STATUS'] = new_status + job_dict["STATUS"] = new_status elif original_status in executing_status_list: # Retrieve job_dict @@ -1332,10 +1467,11 @@ def graceful_shutdown(): num_executing_jobs_before = executing_queue.queue_length() job_dict = executing_queue.dequeue_job_by_id(job_id) num_executing_jobs_after = executing_queue.queue_length() - logger.info('Number of jobs in executing queue. ' - + 'Before cancel: %s. 
After cancel: %s' - % (num_executing_jobs_before, - num_executing_jobs_after)) + logger.info( + "Number of jobs in executing queue. " + + "Before cancel: %s. After cancel: %s" + % (num_executing_jobs_before, num_executing_jobs_after) + ) if not job_dict: @@ -1344,51 +1480,62 @@ def graceful_shutdown(): # trying to cancel logger.info( - 'Cancel job: Could not get job_dict for executing job') + "Cancel job: Could not get job_dict for executing job" + ) continue if not server_cleanup( - job_dict['SESSIONID'], - job_dict['IOSESSIONID'], - job_dict['LOCALJOBNAME'], - job_dict['JOB_ID'], + job_dict["SESSIONID"], + job_dict["IOSESSIONID"], + job_dict["LOCALJOBNAME"], + job_dict["JOB_ID"], configuration, logger, ): - logger.error('could not clean up MiG server') + logger.error("could not clean up MiG server") - if not resource_config.get('SANDBOX', False): + if not resource_config.get("SANDBOX", False): logger.info( - 'Killing running job with atomic_resource_exe_restart') - (status, msg) = \ - atomic_resource_exe_restart(unique_resource_name, - exe, configuration, logger) + "Killing running job with atomic_resource_exe_restart" + ) + status, msg = atomic_resource_exe_restart( + unique_resource_name, exe, configuration, logger + ) if status: - logger.info('atomic_resource_exe_restart ok: res %s:%s' - % (unique_resource_name, exe)) + logger.info( + "atomic_resource_exe_restart ok: res %s:%s" + % (unique_resource_name, exe) + ) else: logger.error( - 'atomic_resource_exe_restart FAILED: %s res %s:%s' - % (msg, unique_resource_name, exe)) + "atomic_resource_exe_restart FAILED: %s res %s:%s" + % (msg, unique_resource_name, exe) + ) # kill_job_by_exe_restart(unique_resource_name, exe, # configuration, logger) # Make sure we do not loose exes even if restart fails - retry_message = 'RESTARTEXEFAILED %s %s %s\n'\ - % (unique_resource_name, exe, job_id) - send_message_to_grid_script(retry_message, logger, - configuration) - elif cap_line.find('JOBTIMEOUT') == 0: + retry_message = "RESTARTEXEFAILED %s %s %s\n" % ( + unique_resource_name, + exe, + job_id, + ) + send_message_to_grid_script( + retry_message, logger, configuration + ) + elif cap_line.find("JOBTIMEOUT") == 0: print(cap_line) logger.info(cap_line) - logger.info('job timeout: %d job(s) in the executing queue.' % - executing_queue.queue_length()) + logger.info( + "job timeout: %d job(s) in the executing queue." + % executing_queue.queue_length() + ) if len(linelist) != 4: - logger.error('Invalid timeout job request') + logger.error("Invalid timeout job request") continue # read values @@ -1397,14 +1544,15 @@ def graceful_shutdown(): exe_name = linelist[2] jobid = linelist[3] - msg = 'JOBTIMEOUT: %s timed out.' % jobid + msg = "JOBTIMEOUT: %s timed out." % jobid print(msg) logger.info(msg) # read resource config file - res_file = os.path.join(configuration.resource_home, - unique_resource_name, 'config') + res_file = os.path.join( + configuration.resource_home, unique_resource_name, "config" + ) resource_config = unpickle(res_file, logger) # Retrieve job_dict @@ -1414,9 +1562,9 @@ def graceful_shutdown(): # Execution information is removed from job_dict in # requeue_job - save here - exe = '' + exe = "" if job_dict: - exe = job_dict['EXE'] + exe = job_dict["EXE"] # Check if job has already been rescheduled due to resource # failure. Important to match both unique resource and exe @@ -1429,24 +1577,27 @@ def graceful_shutdown(): # session id will be invalidated resulting in rejection # and no automatic restart of exe. 
- if job_dict and unique_resource_name\ - == job_dict['UNIQUE_RESOURCE_NAME'] and exe_name == exe: - if 'EMPTY_JOB' in job_dict: + if ( + job_dict + and unique_resource_name == job_dict["UNIQUE_RESOURCE_NAME"] + and exe_name == exe + ): + if "EMPTY_JOB" in job_dict: # Empty job timed out, cleanup server and # remove from Executing queue if not server_cleanup( - job_dict['SESSIONID'], - job_dict['IOSESSIONID'], - job_dict['LOCALJOBNAME'], - job_dict['JOB_ID'], + job_dict["SESSIONID"], + job_dict["IOSESSIONID"], + job_dict["LOCALJOBNAME"], + job_dict["JOB_ID"], configuration, logger, ): - logger.error('could not clean up MiG server') + logger.error("could not clean up MiG server") - executing_queue.dequeue_job_by_id(job_dict['JOB_ID']) + executing_queue.dequeue_job_by_id(job_dict["JOB_ID"]) else: # Real job, requeue job @@ -1456,7 +1607,7 @@ def graceful_shutdown(): scheduler.clear_schedule(job_dict) requeue_job( job_dict, - 'JOB TIMEOUT', + "JOB TIMEOUT", job_queue, executing_queue, configuration, @@ -1465,132 +1616,154 @@ def graceful_shutdown(): # Restart non-sandbox resources for all timed out jobs - if not resource_config.get('SANDBOX', False): + if not resource_config.get("SANDBOX", False): # TODO: atomic_resource_exe_restart is not always effective # The imada resources have been seen to hang in wait for input # files loop across an atomic_resource_exe_restart run # (server PGID was 'starting'). - (status, msg) = \ - atomic_resource_exe_restart(unique_resource_name, - exe, configuration, logger) + status, msg = atomic_resource_exe_restart( + unique_resource_name, exe, configuration, logger + ) if status: - logger.info('atomic_resource_exe_restart ok: res %s:%s' - % (unique_resource_name, exe)) + logger.info( + "atomic_resource_exe_restart ok: res %s:%s" + % (unique_resource_name, exe) + ) else: logger.error( - 'atomic_resource_exe_restart FAILED: %s, res %s:%s' - % (msg, unique_resource_name, exe)) + "atomic_resource_exe_restart FAILED: %s, res %s:%s" + % (msg, unique_resource_name, exe) + ) # Make sure we do not loose exes even if restart fails - retry_message = 'RESTARTEXEFAILED %s %s %s\n'\ - % (unique_resource_name, exe_name, - job_dict['JOB_ID']) - send_message_to_grid_script(retry_message, logger, - configuration) - logger.info('requested restart exe retry attempt') - elif cap_line.find('JOBQUEUEINFO') == 0: + retry_message = "RESTARTEXEFAILED %s %s %s\n" % ( + unique_resource_name, + exe_name, + job_dict["JOB_ID"], + ) + send_message_to_grid_script( + retry_message, logger, configuration + ) + logger.info("requested restart exe retry attempt") + elif cap_line.find("JOBQUEUEINFO") == 0: details = linelist[1:] if not details: - details.append('JOB_ID') - logger.info('--- DISPLAYING JOB QUEUE INFORMATION ---\n%s' % - '\n'.join(job_queue.format_queue(details))) + details.append("JOB_ID") + logger.info( + "--- DISPLAYING JOB QUEUE INFORMATION ---\n%s" + % "\n".join(job_queue.format_queue(details)) + ) job_queue.show_queue(details) - elif cap_line.find('DROPQUEUED') == 0: - logger.info('--- REMOVING JOBS FROM JOB QUEUE ---') + elif cap_line.find("DROPQUEUED") == 0: + logger.info("--- REMOVING JOBS FROM JOB QUEUE ---") job_list = linelist[1:] if not job_list: - logger.info('No jobs specified for removal') + logger.info("No jobs specified for removal") for job_id in job_list: try: job_queue.dequeue_job_by_id(job_id) logger.info("Removed job %s from job queue" % job_id) except Exception as exc: - logger.error("Failed to remove job %s from job queue: %s" - % (job_id, exc)) - elif 
cap_line.find('EXECUTINGQUEUEINFO') == 0: + logger.error( + "Failed to remove job %s from job queue: %s" % (job_id, exc) + ) + elif cap_line.find("EXECUTINGQUEUEINFO") == 0: details = linelist[1:] if not details: - details.append('JOB_ID') - logger.info('--- DISPLAYING EXECUTING QUEUE INFORMATION ---\n%s' % - '\n'.join(executing_queue.format_queue(details))) + details.append("JOB_ID") + logger.info( + "--- DISPLAYING EXECUTING QUEUE INFORMATION ---\n%s" + % "\n".join(executing_queue.format_queue(details)) + ) executing_queue.show_queue(details) - elif cap_line.find('DROPEXECUTING') == 0: - logger.info('--- REMOVING JOBS FROM EXECUTING QUEUE ---') + elif cap_line.find("DROPEXECUTING") == 0: + logger.info("--- REMOVING JOBS FROM EXECUTING QUEUE ---") job_list = linelist[1:] if not job_list: - logger.info('No jobs specified for removal') + logger.info("No jobs specified for removal") for job_id in job_list: try: executing_queue.dequeue_job_by_id(job_id) logger.info("Removed job %s from executing queue" % job_id) except Exception as exc: - logger.error("Failed to remove job %s from exe queue: %s" - % (job_id, exc)) - elif cap_line.find('DONEQUEUEINFO') == 0: + logger.error( + "Failed to remove job %s from exe queue: %s" % (job_id, exc) + ) + elif cap_line.find("DONEQUEUEINFO") == 0: details = linelist[1:] if not details: - details.append('JOB_ID') - logger.info('--- DISPLAYING DONE QUEUE INFORMATION ---\n%s' % - '\n'.join(done_queue.format_queue(details))) + details.append("JOB_ID") + logger.info( + "--- DISPLAYING DONE QUEUE INFORMATION ---\n%s" + % "\n".join(done_queue.format_queue(details)) + ) done_queue.show_queue(details) - elif cap_line.find('DROPDONE') == 0: - logger.info('--- REMOVING JOBS FROM DONE QUEUE ---') + elif cap_line.find("DROPDONE") == 0: + logger.info("--- REMOVING JOBS FROM DONE QUEUE ---") job_list = linelist[1:] if not job_list: - logger.info('No jobs specified for removal') + logger.info("No jobs specified for removal") for job_id in job_list: try: done_queue.dequeue_job_by_id(job_id) logger.info("Removed job %s from done queue" % job_id) except Exception as exc: - logger.error("Failed to remove job %s from exe queue: %s" - % (job_id, exc)) - elif cap_line.find('STARTTIMEOUTTHREAD') == 0: - logger.info('--- STARTING TIME OUT THREAD ---') + logger.error( + "Failed to remove job %s from exe queue: %s" % (job_id, exc) + ) + elif cap_line.find("STARTTIMEOUTTHREAD") == 0: + logger.info("--- STARTING TIME OUT THREAD ---") job_time_out_stop.clear() - job_time_out_thread = threading.Thread(target=time_out_jobs, - args=(job_time_out_stop, )) + job_time_out_thread = threading.Thread( + target=time_out_jobs, args=(job_time_out_stop,) + ) job_time_out_thread.start() - elif cap_line.find('CHECKTIMEOUTTHREAD') == 0: - logger.info('--- CHECKING TIME OUT THREAD ---') - logger.info('--- TIME OUT THREAD IS ALIVE: %s ---' - % job_time_out_thread.is_alive()) - elif cap_line.find('RELOADCONFIG') == 0: - logger.info('--- RELOADING CONFIGURATION ---') + elif cap_line.find("CHECKTIMEOUTTHREAD") == 0: + logger.info("--- CHECKING TIME OUT THREAD ---") + logger.info( + "--- TIME OUT THREAD IS ALIVE: %s ---" + % job_time_out_thread.is_alive() + ) + elif cap_line.find("RELOADCONFIG") == 0: + logger.info("--- RELOADING CONFIGURATION ---") configuration.reload_config(True) - elif cap_line.find('SHUTDOWN') == 0: - logger.info('--- SAFE SHUTDOWN INITIATED ---') - print('--- SAFE SHUTDOWN INITIATED ---') + elif cap_line.find("SHUTDOWN") == 0: + logger.info("--- SAFE SHUTDOWN INITIATED ---") + 
print("--- SAFE SHUTDOWN INITIATED ---") graceful_shutdown() else: - print('not understood: %s' % cap_line) - logger.error('not understood: %s' % cap_line) + print("not understood: %s" % cap_line) + logger.error("not understood: %s" % cap_line) time.sleep(1) # Experimental distributed server code if configuration.enable_server_dist: - servercomm.exchange_status(configuration, scheduler, - loop_counter) + servercomm.exchange_status(configuration, scheduler, loop_counter) # TMP: Auto restart time out thread until we find the death cause if not job_time_out_thread.is_alive(): - logger.warning('--- TIME OUT THREAD DIED: %s %s %s---' - % (job_time_out_thread, - job_time_out_thread.is_alive(), - job_time_out_stop.is_set())) - logger.info('ressurect time out thread with executing queue:') - logger.info('%s' % executing_queue.show_queue(['ALL'])) + logger.warning( + "--- TIME OUT THREAD DIED: %s %s %s---" + % ( + job_time_out_thread, + job_time_out_thread.is_alive(), + job_time_out_stop.is_set(), + ) + ) + logger.info("ressurect time out thread with executing queue:") + logger.info("%s" % executing_queue.show_queue(["ALL"])) job_time_out_stop.clear() - job_time_out_thread = threading.Thread(target=time_out_jobs, - args=(job_time_out_stop, )) + job_time_out_thread = threading.Thread( + target=time_out_jobs, args=(job_time_out_stop,) + ) job_time_out_thread.start() sys.stdout.flush() loop_counter += 1 - logger.debug('loop ended') + logger.debug("loop ended") diff --git a/mig/server/grid_sshmux.py b/mig/server/grid_sshmux.py index 4b1b98cb7..de8de4724 100755 --- a/mig/server/grid_sshmux.py +++ b/mig/server/grid_sshmux.py @@ -32,8 +32,7 @@ error tolerance. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import os import signal @@ -42,8 +41,7 @@ from time import sleep from mig.shared.base import sandbox_resource -from mig.shared.conf import get_resource_configuration, \ - get_configuration_object +from mig.shared.conf import get_configuration_object, get_resource_configuration from mig.shared.logger import daemon_logger, register_hangup_handler from mig.shared.ssh import execute_on_resource @@ -54,32 +52,34 @@ def persistent_connection(resource_config, logger): """Keep running a persistent master connection""" sleep_secs = 300 - hostname = resource_config['HOSTURL'] + hostname = resource_config["HOSTURL"] # Mark this session as a multiplexing master to avoid races: # see further details in shared/ssh.py - resource_config['SSHMULTIPLEXMASTER'] = True + resource_config["SSHMULTIPLEXMASTER"] = True while True: try: - logger.debug('connecting to %s' % hostname) - (exit_code, executed) = execute_on_resource('sleep %d' - % sleep_secs, - False, resource_config, - logger) + logger.debug("connecting to %s" % hostname) + exit_code, executed = execute_on_resource( + "sleep %d" % sleep_secs, False, resource_config, logger + ) if 0 != exit_code: - msg = 'ssh multiplex %s: %s returned %i' % \ - (hostname, executed, exit_code) + msg = "ssh multiplex %s: %s returned %i" % ( + hostname, + executed, + exit_code, + ) print(msg) # make sure control_socket was cleaned up - host = resource_config['HOSTURL'] - identifier = resource_config['HOSTIDENTIFIER'] - unique_id = '%s.%s' % (host, identifier) - control_socket = \ - os.path.join(configuration.resource_home, - unique_id, 'ssh-multiplexing') + host = resource_config["HOSTURL"] + identifier = resource_config["HOSTIDENTIFIER"] + unique_id = "%s.%s" % (host, identifier) + 
control_socket = os.path.join( + configuration.resource_home, unique_id, "ssh-multiplexing" + ) try: os.remove(control_socket) except: @@ -87,13 +87,15 @@ def persistent_connection(resource_config, logger): sleep(sleep_secs) except Exception as err: - msg = '%s thread caught exception (%s) - retry later' % \ - (hostname, err) + msg = "%s thread caught exception (%s) - retry later" % ( + hostname, + err, + ) print(msg) logger.error(msg) sleep(sleep_secs) - msg = '%s thread leaving...' % hostname + msg = "%s thread leaving..." % hostname print(msg) logger.info(msg) @@ -103,7 +105,7 @@ def graceful_shutdown(signum, frame): way. """ - msg = '%s: graceful_shutdown called' % sys.argv[0] + msg = "%s: graceful_shutdown called" % sys.argv[0] print(msg) try: logger.info(msg) @@ -112,12 +114,12 @@ def graceful_shutdown(signum, frame): sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] # Use separate logger @@ -143,12 +145,13 @@ def graceful_shutdown(signum, frame): persistent_hosts = {} resource_path = configuration.resource_home for unique_resource_name in os.listdir(configuration.resource_home): - res_dir = os.path.realpath(configuration.resource_home + os.sep - + unique_resource_name) + res_dir = os.path.realpath( + configuration.resource_home + os.sep + unique_resource_name + ) # skip all dot dirs - they are from repos etc and _not_ jobs - if res_dir.find(os.sep + '.') != -1: + if res_dir.find(os.sep + ".") != -1: continue if not os.path.isdir(res_dir): continue @@ -156,37 +159,39 @@ def graceful_shutdown(signum, frame): if sandbox_resource(dir_name): continue try: - (status, res_conf) = \ - get_resource_configuration(configuration.resource_home, - unique_resource_name, logger) + status, res_conf = get_resource_configuration( + configuration.resource_home, unique_resource_name, logger + ) if not status: continue - if 'SSHMULTIPLEX' in res_conf and res_conf['SSHMULTIPLEX']: - print('adding multiplexing resource %s' % unique_resource_name) - fqdn = res_conf['HOSTURL'] - res_conf['HOMEDIR'] = res_dir + if "SSHMULTIPLEX" in res_conf and res_conf["SSHMULTIPLEX"]: + print("adding multiplexing resource %s" % unique_resource_name) + fqdn = res_conf["HOSTURL"] + res_conf["HOMEDIR"] = res_dir persistent_hosts[fqdn] = res_conf except Exception as err: # else: # print "ignoring non-multiplexing resource %s" % unique_resource_name - print("Failed to open resource conf '%s': %s" - % (unique_resource_name, err)) + print( + "Failed to open resource conf '%s': %s" + % (unique_resource_name, err) + ) threads = {} # register ctrl+c signal handler to shutdown system gracefully signal.signal(signal.SIGINT, graceful_shutdown) - for (hostname, conf) in persistent_hosts.items(): + for hostname, conf in persistent_hosts.items(): if hostname not in threads: - threads[hostname] = \ - threading.Thread(target=persistent_connection, args=(conf, - logger)) + threads[hostname] = threading.Thread( + target=persistent_connection, args=(conf, logger) + ) threads[hostname].setDaemon(True) threads[hostname].start() - print('Send interrupt (ctrl-c) twice to stop persistent connections') + print("Send interrupt (ctrl-c) twice to stop persistent connections") while True: sleep(60) diff --git 
a/mig/server/grid_transfers.py b/mig/server/grid_transfers.py index 843cac19f..cd25914d9 100755 --- a/mig/server/grid_transfers.py +++ b/mig/server/grid_transfers.py @@ -31,10 +31,8 @@ Requires rsync and lftp binaries to take care of the actual transfers. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function -from builtins import zip import datetime import glob import logging @@ -45,21 +43,40 @@ import sys import time import traceback +from builtins import zip from mig.shared.base import client_dir_id, client_id_dir from mig.shared.conf import get_configuration_object -from mig.shared.defaults import datatransfers_filename, transfers_log_size, \ - transfers_log_cnt, user_keys_dir, _user_invisible_paths +from mig.shared.defaults import ( + _user_invisible_paths, + datatransfers_filename, + transfers_log_cnt, + transfers_log_size, + user_keys_dir, +) from mig.shared.fileio import makedirs_rec, pickle from mig.shared.logger import daemon_logger, register_hangup_handler from mig.shared.notification import notify_user_thread -from mig.shared.pwcrypto import unscramble_digest, fernet_decrypt_password -from mig.shared.safeeval import subprocess_popen, subprocess_pipe, \ - subprocess_list2cmdline -from mig.shared.transferfunctions import blind_pw, load_data_transfers, \ - update_data_transfer, get_status_dir, sub_pid_list, add_sub_pid, \ - del_sub_pid, kill_sub_pid, add_worker_transfer, del_worker_transfer, \ - all_worker_transfers, get_worker_transfer +from mig.shared.pwcrypto import fernet_decrypt_password, unscramble_digest +from mig.shared.safeeval import ( + subprocess_list2cmdline, + subprocess_pipe, + subprocess_popen, +) +from mig.shared.transferfunctions import ( + add_sub_pid, + add_worker_transfer, + all_worker_transfers, + blind_pw, + del_sub_pid, + del_worker_transfer, + get_status_dir, + get_worker_transfer, + kill_sub_pid, + load_data_transfers, + sub_pid_list, + update_data_transfer, +) from mig.shared.validstring import valid_user_path # Global helper dictionaries with requests for all users @@ -68,7 +85,7 @@ all_workers = {} sub_pid_map = None stop_running = multiprocessing.Event() -(configuration, logger, last_update) = (None, None, 0) +configuration, logger, last_update = (None, None, 0) # Tune default lftp buffer size - the built-in size is 32k, but a 128k buffer # was experimentally determined to provide significantly better throughput on @@ -85,31 +102,32 @@ # and the resulting file turning up corrupted. 
lftp_sftp_block_bytes = 65536 # Special marker for rsync excludes on list form -RSYNC_EXCLUDES_LIST = '__RSYNC_EXCLUDES_LIST__' +RSYNC_EXCLUDES_LIST = "__RSYNC_EXCLUDES_LIST__" def stop_handler(signal, frame): """A simple signal handler to quit on Ctrl+C (SIGINT) in main""" # Print blank line to avoid mix with Ctrl-C line - print('') + print("") stop_running.set() -def __transfer_log(configuration, client_id, msg, level='info'): +def __transfer_log(configuration, client_id, msg, level="info"): """Wrapper to send a single msg to transfer log file of client_id""" status_dir = get_status_dir(configuration, client_id) log_path = os.path.join(status_dir, configuration.site_transfer_log) makedirs_rec(os.path.dirname(log_path), configuration) - transfers_logger = logging.getLogger('background-transfer') + transfers_logger = logging.getLogger("background-transfer") transfers_logger.setLevel(logging.INFO) handler = logging.handlers.RotatingFileHandler( - log_path, maxBytes=transfers_log_size, backupCount=transfers_log_cnt - 1) - formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') + log_path, maxBytes=transfers_log_size, backupCount=transfers_log_cnt - 1 + ) + formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") handler.setFormatter(formatter) transfers_logger.addHandler(handler) - if level == 'error': + if level == "error": transfers_logger.error(msg) - elif level == 'warning': + elif level == "warning": transfers_logger.warning(msg) else: transfers_logger.info(msg) @@ -120,39 +138,49 @@ def __transfer_log(configuration, client_id, msg, level='info'): def transfer_error(configuration, client_id, msg): """Wrapper to send a single error msg to transfer log of client_id""" - __transfer_log(configuration, client_id, msg, 'error') + __transfer_log(configuration, client_id, msg, "error") def transfer_warn(configuration, client_id, msg): """Wrapper to send a single warn msg to transfer log of client_id""" - __transfer_log(configuration, client_id, msg, 'warning') + __transfer_log(configuration, client_id, msg, "warning") def transfer_info(configuration, client_id, msg): """Wrapper to send a single info msg to transfer log of client_id""" - __transfer_log(configuration, client_id, msg, 'info') + __transfer_log(configuration, client_id, msg, "info") -def transfer_result(configuration, client_id, transfer_dict, exit_code, - out_msg, err_msg): +def transfer_result( + configuration, client_id, transfer_dict, exit_code, out_msg, err_msg +): """Update status file from transfer_dict with the result from transfer that reurned exit_code, out_msg and err_msg. 
""" time_stamp = datetime.datetime.now().ctime() - transfer_id = transfer_dict['transfer_id'] + transfer_id = transfer_dict["transfer_id"] rel_src = transfer_dict.get("rel_src", False) if not rel_src: - rel_src = ', '.join(transfer_dict['src']) + rel_src = ", ".join(transfer_dict["src"]) res_dir = get_status_dir(configuration, client_id, transfer_id) makedirs_rec(res_dir, configuration) - status_msg = '''%s: %s %s of %s in %s finished with status %s -''' % (time_stamp, transfer_dict['protocol'], transfer_dict['action'], rel_src, - transfer_dict['transfer_id'], exit_code) - out_msg = '%s:\n%s\n' % (time_stamp, out_msg) - err_msg = '%s:\n%s\n' % (time_stamp, err_msg) + status_msg = """%s: %s %s of %s in %s finished with status %s +""" % ( + time_stamp, + transfer_dict["protocol"], + transfer_dict["action"], + rel_src, + transfer_dict["transfer_id"], + exit_code, + ) + out_msg = "%s:\n%s\n" % (time_stamp, out_msg) + err_msg = "%s:\n%s\n" % (time_stamp, err_msg) status = True - for (ext, msg) in [("status", status_msg), ("stdout", out_msg), - ("stderr", err_msg)]: + for ext, msg in [ + ("status", status_msg), + ("stdout", out_msg), + ("stderr", err_msg), + ]: path = os.path.join(res_dir, "%s.%s" % (transfer_id, ext)) try: if os.path.exists(path): @@ -162,8 +190,10 @@ def transfer_result(configuration, client_id, transfer_dict, exit_code, status_fd.write(msg) status_fd.close() except Exception as exc: - logger.error("writing status file %s for %s failed: %s" % - (path, blind_pw(transfer_dict), exc)) + logger.error( + "writing status file %s for %s failed: %s" + % (path, blind_pw(transfer_dict), exc) + ) status = False return status @@ -223,10 +253,10 @@ def get_exclude_list(keyword, sep_char, to_string, user_excludes=[]): NOTE: list is passed into subprocess without shell interpretation so quoting is NOT needed, and in fact would break the excludes. """ - exc_pattern = '%s%s%%s' % (keyword, sep_char) + exc_pattern = "%s%s%%s" % (keyword, sep_char) all_excludes = user_excludes + _user_invisible_paths if to_string: - return ' '.join([exc_pattern % i for i in all_excludes]) + return " ".join([exc_pattern % i for i in all_excludes]) else: return [exc_pattern % i for i in all_excludes] @@ -242,11 +272,11 @@ def get_lftp_target(is_import, is_file, user_excludes=[]): suitable for eventually plugging into the command list from command map. """ lftp_args = [] - src = '%(src)s' - dst = '%(dst)s/' + src = "%(src)s" + dst = "%(dst)s/" # Pack all login and address into src to avoid problems with e.g. Amazon S3 # refusing explicit open on base URL first. - remote = '%(protocol)s://%(username)s:%(password)s@%(fqdn)s:%(port)s/' + remote = "%(protocol)s://%(username)s:%(password)s@%(fqdn)s:%(port)s/" if is_import: src = remote + src else: @@ -268,8 +298,7 @@ def get_lftp_target(is_import, is_file, user_excludes=[]): # There's a slight difference in the handling of exclude in lftp and # rsync with the former using regex and the latter using glob by # default. It is not obvious to unify the format in either way. - exclude_list = get_exclude_list('--exclude', ' ', False, - user_excludes) + exclude_list = get_exclude_list("--exclude", " ", False, user_excludes) # IMPORTANT: Use Resume, follow symlinks and keep all but suid perms. # We DON'T preserve device files, owner/group and can't # preserve timestamps. @@ -298,15 +327,15 @@ def get_rsync_target(is_import, is_file, user_excludes=[], compress=False): # IMPORTANT: Follow symlinks, preserve executability and timestamps. 
# We DON'T preserve device files, owner/group and other perms. # NOTE: enabling -S (efficient sparse file handling) kills performance. - rsync_args = ['-LEt'] + rsync_args = ["-LEt"] if not is_file: - rsync_args[0] += 'r' + rsync_args[0] += "r" if compress: - rsync_args[0] += 'z' + rsync_args[0] += "z" # NOTE: enable argument protection in protocol instead of quoting src/dst - rsync_args[0] += 's' + rsync_args[0] += "s" # NOTE: we actively filter illegal paths from all rsync transfers - exclude_list = get_exclude_list('--exclude', '=', False, user_excludes) + exclude_list = get_exclude_list("--exclude", "=", False, user_excludes) # NOTE: rsync fails with explicit quotes in the path on remote side if we # use the same quoting here as with lftp. We leave them out and rely # on subprocess without shell locally and protect args flag remotely. @@ -353,125 +382,318 @@ def get_cmd_map(): # lftp -c "set xfer:buffer-size $BUFSIZE ; set sftp:size-read $BUFSIZE ; set sftp:size-write $BUFSIZE ; open -u USERNAME,PW -p 22 sftp://io.erda.dk ; get -O build/ build/dblzeros.bin" # lftp -c "set xfer:buffer-size $BUFSIZE; set ftp:ssl-force ; set ftp:ssl-protect-data on ; open -u USERNAME,PW -p 8021 ftp://io.erda.dk ; get -O /tmp welcome.txt" - cmd_map = {'import': - {'sftp': ['lftp', '-c', - ';'.join([lftp_core_opts, sftp_buf_str, sftp_key_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'ftp': ['lftp', '-c', - ';'.join([lftp_core_opts, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'ftps': ['lftp', '-c', - ';'.join([lftp_core_opts, base_ssl_str, ftps_ssl_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'http': ['lftp', '-c', - ';'.join([lftp_core_opts, http_tweak_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'https': ['lftp', '-c', - ';'.join([lftp_core_opts, http_tweak_str, - base_ssl_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'webdav': ['lftp', '-c', - ';'.join([lftp_core_opts, webdav_tweak_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'webdavs': ['lftp', '-c', - ';'.join([lftp_core_opts, webdav_tweak_str, - base_ssl_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'rsyncssh': ['rsync', '-e', rsyncssh_transport_str] + - rsync_core_opts + ['%(rsync_args)s', RSYNC_EXCLUDES_LIST, - '%(rsync_src)s', '%(rsync_dst)s'], - 'rsyncd': ['rsync'] + rsync_core_opts + - ['%(rsync_args)s', RSYNC_EXCLUDES_LIST, - '%(rsync_src)s', '%(rsync_dst)s'], - }, - 'export': - {'sftp': ['lftp', '-c', - ';'.join([lftp_core_opts, sftp_buf_str, sftp_key_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'ftp': ['lftp', '-c', - ';'.join([lftp_core_opts, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'ftps': ['lftp', '-c', - ';'.join([lftp_core_opts, base_ssl_str, ftps_ssl_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'http': ['lftp', '-c', - ';'.join([lftp_core_opts, http_tweak_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'https': ['lftp', '-c', - ';'.join([lftp_core_opts, http_tweak_str, base_ssl_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - 
'%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'webdav': ['lftp', '-c', - ';'.join([lftp_core_opts, webdav_tweak_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'webdavs': ['lftp', '-c', - ';'.join([lftp_core_opts, webdav_tweak_str, - base_ssl_str, - ' '.join(['%(lftp_args)s', - '%(lftp_excludes)s', - '%(lftp_src)s', '%(lftp_dst)s']) - ])], - 'rsyncssh': ['rsync', '-e', rsyncssh_transport_str] + - rsync_core_opts + ['%(rsync_args)s', RSYNC_EXCLUDES_LIST, - '%(rsync_src)s', '%(rsync_dst)s'], - 'rsyncd': ['rsync'] + rsync_core_opts + - ['%(rsync_args)s', RSYNC_EXCLUDES_LIST, - '%(rsync_src)s', '%(rsync_dst)s'], - } - } + cmd_map = { + "import": { + "sftp": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + sftp_buf_str, + sftp_key_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "ftp": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "ftps": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + base_ssl_str, + ftps_ssl_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "http": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + http_tweak_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "https": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + http_tweak_str, + base_ssl_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "webdav": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + webdav_tweak_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "webdavs": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + webdav_tweak_str, + base_ssl_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "rsyncssh": ["rsync", "-e", rsyncssh_transport_str] + + rsync_core_opts + + [ + "%(rsync_args)s", + RSYNC_EXCLUDES_LIST, + "%(rsync_src)s", + "%(rsync_dst)s", + ], + "rsyncd": ["rsync"] + + rsync_core_opts + + [ + "%(rsync_args)s", + RSYNC_EXCLUDES_LIST, + "%(rsync_src)s", + "%(rsync_dst)s", + ], + }, + "export": { + "sftp": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + sftp_buf_str, + sftp_key_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "ftp": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "ftps": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + base_ssl_str, + ftps_ssl_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "http": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + http_tweak_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "https": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + http_tweak_str, + base_ssl_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "webdav": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + 
webdav_tweak_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "webdavs": [ + "lftp", + "-c", + ";".join( + [ + lftp_core_opts, + webdav_tweak_str, + base_ssl_str, + " ".join( + [ + "%(lftp_args)s", + "%(lftp_excludes)s", + "%(lftp_src)s", + "%(lftp_dst)s", + ] + ), + ] + ), + ], + "rsyncssh": ["rsync", "-e", rsyncssh_transport_str] + + rsync_core_opts + + [ + "%(rsync_args)s", + RSYNC_EXCLUDES_LIST, + "%(rsync_src)s", + "%(rsync_dst)s", + ], + "rsyncd": ["rsync"] + + rsync_core_opts + + [ + "%(rsync_args)s", + RSYNC_EXCLUDES_LIST, + "%(rsync_src)s", + "%(rsync_dst)s", + ], + }, + } return cmd_map def run_transfer(configuration, client_id, transfer_dict): """Actual data transfer built from transfer_dict on behalf of client_id""" - logger.debug('run transfer for %s: %s' % (client_id, - blind_pw(transfer_dict))) - transfer_id = transfer_dict['transfer_id'] - action = transfer_dict['action'] - protocol = transfer_dict['protocol'] + logger.debug( + "run transfer for %s: %s" % (client_id, blind_pw(transfer_dict)) + ) + transfer_id = transfer_dict["transfer_id"] + action = transfer_dict["action"] + protocol = transfer_dict["protocol"] status_dir = get_status_dir(configuration, client_id, transfer_id) cmd_map = get_cmd_map() if not protocol in cmd_map[action]: - raise ValueError('unsupported protocol: %s' % protocol) + raise ValueError("unsupported protocol: %s" % protocol) client_dir = client_id_dir(client_id) makedirs_rec(status_dir, configuration) @@ -479,31 +701,40 @@ def run_transfer(configuration, client_id, transfer_dict): # Please note that base_dir must end in slash to avoid access to other # user dirs when own name is a prefix of another user name - base_dir = os.path.abspath(os.path.join(configuration.user_home, - client_dir)) + os.sep + base_dir = ( + os.path.abspath(os.path.join(configuration.user_home, client_dir)) + + os.sep + ) # TODO: we should refactor to move command extraction into one function command_pattern = cmd_map[action][protocol] target_helper_list = [] key_path = transfer_dict.get("key", "") if key_path: # Use key with given name from settings dir - settings_base_dir = os.path.abspath(os.path.join( - configuration.user_settings, client_dir)) + os.sep - key_path = os.path.join(settings_base_dir, user_keys_dir, - key_path.lstrip(os.sep)) + settings_base_dir = ( + os.path.abspath( + os.path.join(configuration.user_settings, client_dir) + ) + + os.sep + ) + key_path = os.path.join( + settings_base_dir, user_keys_dir, key_path.lstrip(os.sep) + ) # IMPORTANT: path must be expanded to abs for proper chrooting key_path = os.path.abspath(key_path) if not valid_user_path(configuration, key_path, settings_base_dir): - logger.error('rejecting illegal directory traversal for %s (%s)' - % (key_path, blind_pw(transfer_dict))) + logger.error( + "rejecting illegal directory traversal for %s (%s)" + % (key_path, blind_pw(transfer_dict)) + ) raise ValueError("user provided a key outside own settings!") - rel_src_list = transfer_dict['src'] - rel_dst = transfer_dict['dst'] + rel_src_list = transfer_dict["src"] + rel_dst = transfer_dict["dst"] compress = transfer_dict.get("compress", False) exclude = transfer_dict.get("exclude", []) - if transfer_dict['action'] in ('import', ): - logger.debug('setting abs dst for action %(action)s' % transfer_dict) - src_path_list = transfer_dict['src'] + if transfer_dict["action"] in ("import",): + logger.debug("setting abs dst for action %(action)s" % transfer_dict) + 
src_path_list = transfer_dict["src"] dst_path = os.path.join(base_dir, rel_dst.lstrip(os.sep)) dst_path = os.path.abspath(dst_path) for src in rel_src_list: @@ -512,25 +743,29 @@ def run_transfer(configuration, client_id, transfer_dict): abs_dst = os.path.abspath(abs_dst) # Reject illegal directory traversal and hidden files if not valid_user_path(configuration, abs_dst, base_dir, True): - logger.error('rejecting illegal directory traversal for %s (%s)' - % (abs_dst, blind_pw(transfer_dict))) + logger.error( + "rejecting illegal directory traversal for %s (%s)" + % (abs_dst, blind_pw(transfer_dict)) + ) raise ValueError("user provided a destination outside home!") if src.endswith(os.sep): - target_helper_list.append((get_lftp_target(True, False, - exclude), - get_rsync_target(True, False, - exclude, - compress))) + target_helper_list.append( + ( + get_lftp_target(True, False, exclude), + get_rsync_target(True, False, exclude, compress), + ) + ) else: - target_helper_list.append((get_lftp_target(True, True, - exclude), - get_rsync_target(True, True, - exclude, - compress))) + target_helper_list.append( + ( + get_lftp_target(True, True, exclude), + get_rsync_target(True, True, exclude, compress), + ) + ) makedirs_rec(dst_path, configuration) - elif transfer_dict['action'] in ('export', ): - logger.debug('setting abs src for action %(action)s' % transfer_dict) - dst_path = transfer_dict['dst'] + elif transfer_dict["action"] in ("export",): + logger.debug("setting abs src for action %(action)s" % transfer_dict) + dst_path = transfer_dict["dst"] src_path_list = [] for src in rel_src_list: src_path = os.path.join(base_dir, src.lstrip(os.sep)) @@ -538,103 +773,115 @@ def run_transfer(configuration, client_id, transfer_dict): src_path = os.path.abspath(src_path) # Reject illegal directory traversal and hidden files if not valid_user_path(configuration, src_path, base_dir, True): - logger.error('rejecting illegal directory traversal for %s (%s)' - % (src, blind_pw(transfer_dict))) + logger.error( + "rejecting illegal directory traversal for %s (%s)" + % (src, blind_pw(transfer_dict)) + ) raise ValueError("user provided a source outside home!") src_path_list.append(src_path) if src.endswith(os.sep) or os.path.isdir(src): - target_helper_list.append((get_lftp_target(False, False, - exclude), - get_rsync_target(False, False, - exclude, - compress))) + target_helper_list.append( + ( + get_lftp_target(False, False, exclude), + get_rsync_target(False, False, exclude, compress), + ) + ) else: - target_helper_list.append((get_lftp_target(False, True, - exclude), - get_rsync_target(False, True, - exclude, - compress))) + target_helper_list.append( + ( + get_lftp_target(False, True, exclude), + get_rsync_target(False, True, exclude, compress), + ) + ) else: - raise ValueError('unsupported action for %(transfer_id)s: %(action)s' - % transfer_dict) + raise ValueError( + "unsupported action for %(transfer_id)s: %(action)s" % transfer_dict + ) run_dict = transfer_dict.copy() - run_dict['log_path'] = os.path.join(status_dir, 'transfer.log') + run_dict["log_path"] = os.path.join(status_dir, "transfer.log") # Use private known hosts file for ssh transfers as explained above # NOTE: known_hosts containing '=' silently leads to rest getting ignored! # use /dev/null to skip host key verification completely for now. 
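# NOTE: purely a hedged sketch of how strict host key checking could be
# restored once the '=' parsing problem is solved: reuse the per-user path
# from the commented-out line below and only fall back to /dev/null when no
# pinned keys exist (illustrative, not current behavior):
#
#     user_known_hosts = os.path.join(base_dir, ".ssh", "known_hosts")
#     if os.path.isfile(user_known_hosts):
#         run_dict["known_hosts"] = user_known_hosts  # verify pinned host keys
#     else:
#         run_dict["known_hosts"] = "/dev/null"  # skip verification entirely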
# run_dict['known_hosts'] = os.path.join(base_dir, '.ssh', 'known_hosts') - run_dict['known_hosts'] = '/dev/null' + run_dict["known_hosts"] = "/dev/null" # Make sure password is set to empty string as default - run_dict['password'] = run_dict.get('password', '') + run_dict["password"] = run_dict.get("password", "") # TODO: this is a bogus cert path for now - we don't support ssl certs - run_dict['cert'] = run_dict.get('cert', '') + run_dict["cert"] = run_dict.get("cert", "") # IMPORTANT: must be implicit proto or 'ftp://' (not ftps://) and similarly # webdav(s) must use explicit http(s) instead. In both cases we # replace protocol between cmd selection and lftp path expansion - if run_dict['protocol'] == 'ftps': - run_dict['orig_proto'] = run_dict['protocol'] - run_dict['protocol'] = 'ftp' - logger.info('force %(orig_proto)s to %(protocol)s for %(transfer_id)s' - % run_dict) - elif run_dict['protocol'].startswith('webdav'): - run_dict['orig_proto'] = run_dict['protocol'] - run_dict['protocol'] = run_dict['protocol'].replace('webdav', 'http') - logger.info('force %(orig_proto)s to %(protocol)s for %(transfer_id)s' - % run_dict) + if run_dict["protocol"] == "ftps": + run_dict["orig_proto"] = run_dict["protocol"] + run_dict["protocol"] = "ftp" + logger.info( + "force %(orig_proto)s to %(protocol)s for %(transfer_id)s" + % run_dict + ) + elif run_dict["protocol"].startswith("webdav"): + run_dict["orig_proto"] = run_dict["protocol"] + run_dict["protocol"] = run_dict["protocol"].replace("webdav", "http") + logger.info( + "force %(orig_proto)s to %(protocol)s for %(transfer_id)s" + % run_dict + ) if key_path: - rel_key = run_dict['key'] - rel_cert = run_dict['cert'] - run_dict['key'] = key_path - run_dict['cert'] = key_path.replace(rel_key, rel_cert) - run_dict['ssh_auth'] = get_ssh_auth(True, run_dict) - run_dict['ssl_auth'] = get_ssl_auth(True, run_dict) + rel_key = run_dict["key"] + rel_cert = run_dict["cert"] + run_dict["key"] = key_path + run_dict["cert"] = key_path.replace(rel_key, rel_cert) + run_dict["ssh_auth"] = get_ssh_auth(True, run_dict) + run_dict["ssl_auth"] = get_ssl_auth(True, run_dict) else: # Extract encrypted or digest password if set - password_encrypted = run_dict.get('password_encrypted', '') - password_digest = run_dict.get('password_digest', '') + password_encrypted = run_dict.get("password_encrypted", "") + password_digest = run_dict.get("password_digest", "") if password_encrypted: - run_dict['password'] = fernet_decrypt_password(configuration, - password_encrypted) + run_dict["password"] = fernet_decrypt_password( + configuration, password_encrypted + ) elif password_digest: _, _, _, payload = password_digest.split("$") - unscrambled = unscramble_digest(configuration.site_digest_salt, - payload) + unscrambled = unscramble_digest( + configuration.site_digest_salt, payload + ) _, _, password = unscrambled.split(":") - run_dict['password'] = password - run_dict['ssh_auth'] = get_ssh_auth(False, run_dict) - run_dict['ssl_auth'] = get_ssl_auth(False, run_dict) - run_dict['rel_dst'] = rel_dst - run_dict['dst'] = dst_path - run_dict['lftp_buf_size'] = run_dict.get('lftp_buf_size', - lftp_buffer_bytes) - run_dict['lftp_sftp_block_size'] = run_dict.get('sftp_sftp_block_size', - lftp_sftp_block_bytes) + run_dict["password"] = password + run_dict["ssh_auth"] = get_ssh_auth(False, run_dict) + run_dict["ssl_auth"] = get_ssl_auth(False, run_dict) + run_dict["rel_dst"] = rel_dst + run_dict["dst"] = dst_path + run_dict["lftp_buf_size"] = run_dict.get("lftp_buf_size", 
lftp_buffer_bytes) + run_dict["lftp_sftp_block_size"] = run_dict.get( + "sftp_sftp_block_size", lftp_sftp_block_bytes + ) status = 0 - for (src, rel_src, target_helper) in zip(src_path_list, rel_src_list, - target_helper_list): - (lftp_target, rsync_target) = target_helper - logger.debug('setting up %(action)s for %(src)s' % run_dict) - if run_dict['protocol'] == 'sftp' and not os.path.isabs(src): + for src, rel_src, target_helper in zip( + src_path_list, rel_src_list, target_helper_list + ): + lftp_target, rsync_target = target_helper + logger.debug("setting up %(action)s for %(src)s" % run_dict) + if run_dict["protocol"] == "sftp" and not os.path.isabs(src): # NOTE: lftp interprets sftp://FQDN/SRC as absolute path /SRC # We force relative paths into user home with a tilde. # The resulting sftp://FQDN/~/SRC looks funky but works. - run_dict['src'] = "~/%s" % src + run_dict["src"] = "~/%s" % src else: # All other paths are probably absolute or auto-chrooted anyway - run_dict['src'] = src - run_dict['rel_src'] = rel_src - run_dict['lftp_args'] = ' '.join(lftp_target[0]) % run_dict - run_dict['lftp_excludes'] = ' '.join(lftp_target[1]) + run_dict["src"] = src + run_dict["rel_src"] = rel_src + run_dict["lftp_args"] = " ".join(lftp_target[0]) % run_dict + run_dict["lftp_excludes"] = " ".join(lftp_target[1]) # src and dst may actually be reversed for lftp, but for symmetry ... - run_dict['lftp_src'] = lftp_target[2][0] % run_dict - run_dict['lftp_dst'] = lftp_target[2][1] % run_dict - run_dict['rsync_args'] = ' '.join(rsync_target[0]) % run_dict + run_dict["lftp_src"] = lftp_target[2][0] % run_dict + run_dict["lftp_dst"] = lftp_target[2][1] % run_dict + run_dict["rsync_args"] = " ".join(rsync_target[0]) % run_dict # Preserve excludes on list form for rsync, where it matters run_dict[RSYNC_EXCLUDES_LIST] = rsync_target[1] - run_dict['rsync_src'] = rsync_target[2][0] % run_dict - run_dict['rsync_dst'] = rsync_target[2][1] % run_dict + run_dict["rsync_src"] = rsync_target[2][0] % run_dict + run_dict["rsync_dst"] = rsync_target[2][1] % run_dict blind_dict = blind_pw(run_dict) - logger.debug('expanded vars to %s' % blind_dict) + logger.debug("expanded vars to %s" % blind_dict) # NOTE: Make sure NOT to break rsync excludes on list form as they # won't work if concatenated to a single string in command_list! command_list, blind_list = [], [] @@ -645,66 +892,77 @@ def run_transfer(configuration, client_id, transfer_dict): else: command_list.append(i % run_dict) blind_list.append(i % blind_dict) - command_str = ' '.join(command_list) + command_str = " ".join(command_list) # NOTE: we wrap list entries in quotes for usable log line blind_str = subprocess_list2cmdline(blind_list) - logger.info('run %s on behalf of %s' % (blind_str, client_id)) + logger.info("run %s on behalf of %s" % (blind_str, client_id)) # NOTE: we want utf8-encoded output as text str for mangling below - transfer_proc = subprocess_popen(command_list, - stdout=subprocess_pipe, - stderr=subprocess_pipe, - text=True) + transfer_proc = subprocess_popen( + command_list, + stdout=subprocess_pipe, + stderr=subprocess_pipe, + text=True, + ) # Save transfer_proc.pid for use in clean up during shutdown # in that way we can resume pretty smoothly in next run. 
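# NOTE: a minimal sketch of that bookkeeping pattern with a plain dict; the
# real helpers add_sub_pid/del_sub_pid below follow the same register-before-
# blocking, deregister-after-wait pattern against the shared multiprocessing
# map sub_pid_map (helper names here are illustrative only):
#
#     def track_pid(pid_map, key, pid):
#         """Record a child pid so shutdown cleanup can find it later."""
#         pid_map.setdefault(key, []).append(pid)
#
#     def untrack_pid(pid_map, key, pid):
#         """Forget a child pid once it has been waited for."""
#         if pid in pid_map.get(key, []):
#             pid_map[key].remove(pid)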
sub_pid = transfer_proc.pid - logger.info('%s %s running transfer process %s' % (client_id, - transfer_id, - sub_pid)) - add_sub_pid(configuration, sub_pid_map, client_id, transfer_id, - sub_pid) + logger.info( + "%s %s running transfer process %s" + % (client_id, transfer_id, sub_pid) + ) + add_sub_pid(configuration, sub_pid_map, client_id, transfer_id, sub_pid) out, err = transfer_proc.communicate() exit_code = transfer_proc.wait() status |= exit_code - del_sub_pid(configuration, sub_pid_map, client_id, transfer_id, - sub_pid) - logger.info('done running transfer %s: %s' % (transfer_id, blind_str)) - logger.debug('raw output is: %s' % out) - logger.debug('raw error is: %s' % err) - logger.debug('result was %s' % exit_code) - if not transfer_result(configuration, client_id, run_dict, exit_code, - out.replace(base_dir, ''), - err.replace(base_dir, '')): - logger.error('writing transfer status for %s failed' % transfer_id) - - logger.debug('done handling transfers in %(transfer_id)s' % transfer_dict) - transfer_dict['exit_code'] = status + del_sub_pid(configuration, sub_pid_map, client_id, transfer_id, sub_pid) + logger.info("done running transfer %s: %s" % (transfer_id, blind_str)) + logger.debug("raw output is: %s" % out) + logger.debug("raw error is: %s" % err) + logger.debug("result was %s" % exit_code) + if not transfer_result( + configuration, + client_id, + run_dict, + exit_code, + out.replace(base_dir, ""), + err.replace(base_dir, ""), + ): + logger.error("writing transfer status for %s failed" % transfer_id) + + logger.debug("done handling transfers in %(transfer_id)s" % transfer_dict) + transfer_dict["exit_code"] = status if status == 0: - transfer_dict['status'] = 'DONE' + transfer_dict["status"] = "DONE" else: - transfer_dict['status'] = 'FAILED' + transfer_dict["status"] = "FAILED" def clean_transfer(configuration, client_id, transfer_id, force=False): """Actually clean transfer worker from client_id and transfer_id""" - logger.debug('in cleaning of %s %s' % (client_id, transfer_id)) - worker = get_worker_transfer(configuration, all_workers, client_id, - transfer_id) - logger.debug('cleaning worker %s for %s %s' % (worker, client_id, - transfer_id)) + logger.debug("in cleaning of %s %s" % (client_id, transfer_id)) + worker = get_worker_transfer( + configuration, all_workers, client_id, transfer_id + ) + logger.debug( + "cleaning worker %s for %s %s" % (worker, client_id, transfer_id) + ) del_worker_transfer(configuration, all_workers, client_id, transfer_id) - sub_procs = sub_pid_list(configuration, sub_pid_map, client_id, - transfer_id) - logger.debug('cleaning sub procs %s for %s %s' % (sub_procs, client_id, - transfer_id)) + sub_procs = sub_pid_list(configuration, sub_pid_map, client_id, transfer_id) + logger.debug( + "cleaning sub procs %s for %s %s" % (sub_procs, client_id, transfer_id) + ) for sub_pid in sub_procs: if not force: - logger.warning('left-over child in %s %s: %s' % - (client_id, transfer_id, sub_procs)) + logger.warning( + "left-over child in %s %s: %s" + % (client_id, transfer_id, sub_procs) + ) if not kill_sub_pid(configuration, client_id, transfer_id, sub_pid): - logger.error('could not terminate child process in %s %s: %s' % - (client_id, transfer_id, sub_procs)) - del_sub_pid(configuration, sub_pid_map, client_id, transfer_id, - sub_pid) + logger.error( + "could not terminate child process in %s %s: %s" + % (client_id, transfer_id, sub_procs) + ) + del_sub_pid(configuration, sub_pid_map, client_id, transfer_id, sub_pid) def 
wrap_run_transfer(configuration, client_id, transfer_dict): @@ -712,51 +970,74 @@ def wrap_run_transfer(configuration, client_id, transfer_dict): caught and logged. Updates state, calls the run_transfer function on input and finally updates state again afterwards. """ - transfer_id = transfer_dict['transfer_id'] - transfer_dict['status'] = "ACTIVE" - transfer_dict['exit_code'] = -1 - all_transfers[client_id][transfer_id]['status'] = transfer_dict['status'] - (save_status, save_msg) = update_data_transfer(configuration, client_id, - transfer_dict) + transfer_id = transfer_dict["transfer_id"] + transfer_dict["status"] = "ACTIVE" + transfer_dict["exit_code"] = -1 + all_transfers[client_id][transfer_id]["status"] = transfer_dict["status"] + save_status, save_msg = update_data_transfer( + configuration, client_id, transfer_dict + ) if not save_status: - logger.error("failed to save %s status for %s: %s" % - (transfer_dict['status'], transfer_id, save_msg)) + logger.error( + "failed to save %s status for %s: %s" + % (transfer_dict["status"], transfer_id, save_msg) + ) return save_status try: run_transfer(configuration, client_id, transfer_dict) except Exception as exc: logger.error("run transfer failed: %s" % exc) logger.error(traceback.format_exc()) - transfer_dict['status'] = "FAILED" - if not transfer_result(configuration, client_id, transfer_dict, - transfer_dict['exit_code'], '', - 'Fatal error during transfer: %s' % exc): - logger.error('writing transfer status for %s failed' % transfer_id) - - all_transfers[client_id][transfer_id]['status'] = transfer_dict['status'] - (save_status, save_msg) = update_data_transfer(configuration, client_id, - transfer_dict) + transfer_dict["status"] = "FAILED" + if not transfer_result( + configuration, + client_id, + transfer_dict, + transfer_dict["exit_code"], + "", + "Fatal error during transfer: %s" % exc, + ): + logger.error("writing transfer status for %s failed" % transfer_id) + + all_transfers[client_id][transfer_id]["status"] = transfer_dict["status"] + save_status, save_msg = update_data_transfer( + configuration, client_id, transfer_dict + ) if not save_status: - logger.error("failed to save %s status for %s: %s" % - (transfer_dict['status'], transfer_id, save_msg)) - - status_msg = '%s %s from %s in %s %s with status code %s' % \ - (transfer_dict['protocol'], transfer_dict['action'], - transfer_dict['fqdn'], transfer_id, transfer_dict['status'], - transfer_dict['exit_code']) - if transfer_dict['status'] == 'FAILED': + logger.error( + "failed to save %s status for %s: %s" + % (transfer_dict["status"], transfer_id, save_msg) + ) + + status_msg = "%s %s from %s in %s %s with status code %s" % ( + transfer_dict["protocol"], + transfer_dict["action"], + transfer_dict["fqdn"], + transfer_id, + transfer_dict["status"], + transfer_dict["exit_code"], + ) + if transfer_dict["status"] == "FAILED": transfer_error(configuration, client_id, status_msg) else: transfer_info(configuration, client_id, status_msg) - notify = transfer_dict.get('notify', False) + notify = transfer_dict.get("notify", False) if notify: - job_dict = {'NOTIFY': [notify], 'JOB_ID': 'NOJOBID', - 'USER_CERT': client_id} + job_dict = { + "NOTIFY": [notify], + "JOB_ID": "NOJOBID", + "USER_CERT": client_id, + } job_dict.update(transfer_dict) logger.info("notify for %(transfer_id)s: %(notify)s" % transfer_dict) notifier = notify_user_thread( - job_dict, [transfer_id, job_dict['status'], status_msg], - 'TRANSFERCOMPLETE', logger, '', configuration) + job_dict, + [transfer_id, 
job_dict["status"], status_msg], + "TRANSFERCOMPLETE", + logger, + "", + configuration, + ) # Try finishing delivery but do not block forever on one message notifier.join(30) logger.info("finished wrap run transfer %(transfer_id)s" % transfer_dict) @@ -766,36 +1047,39 @@ def background_transfer(configuration, client_id, transfer_dict): """Run a transfer in the background so that it can block without stopping further transfer handling. """ - transfer_id = transfer_dict['transfer_id'] - worker = multiprocessing.Process(target=wrap_run_transfer, - args=(configuration, client_id, - transfer_dict)) + transfer_id = transfer_dict["transfer_id"] + worker = multiprocessing.Process( + target=wrap_run_transfer, args=(configuration, client_id, transfer_dict) + ) worker.start() - add_worker_transfer(configuration, all_workers, client_id, transfer_id, - worker) + add_worker_transfer( + configuration, all_workers, client_id, transfer_id, worker + ) def foreground_transfer(configuration, client_id, transfer_dict): """Run a transfer in the foreground so that it can block without stopping further transfer handling. """ - transfer_id = transfer_dict['transfer_id'] - add_worker_transfer(configuration, all_workers, client_id, transfer_id, - None) + transfer_id = transfer_dict["transfer_id"] + add_worker_transfer( + configuration, all_workers, client_id, transfer_id, None + ) wrap_run_transfer(configuration, client_id, transfer_dict) del_worker_transfer(configuration, all_workers, client_id, transfer_id) def handle_transfer(configuration, client_id, transfer_dict): """Actually handle valid transfer request in transfer_dict""" - logger.debug('in handling of %s %s for %s' % (transfer_dict['transfer_id'], - transfer_dict['action'], - client_id)) - if transfer_dict['status'] == "ACTIVE": - msg = 'transfer service restarted: resume interrupted %(transfer_id)s ' - msg += '%(action)s (please ignore any recent log errors)' + logger.debug( + "in handling of %s %s for %s" + % (transfer_dict["transfer_id"], transfer_dict["action"], client_id) + ) + if transfer_dict["status"] == "ACTIVE": + msg = "transfer service restarted: resume interrupted %(transfer_id)s " + msg += "%(action)s (please ignore any recent log errors)" else: - msg = 'start %(transfer_id)s %(action)s' + msg = "start %(transfer_id)s %(action)s" transfer_info(configuration, client_id, msg % transfer_dict) try: @@ -803,77 +1087,97 @@ def handle_transfer(configuration, client_id, transfer_dict): # foreground_transfer(configuration, client_id, transfer_dict) background_transfer(configuration, client_id, transfer_dict) except Exception as exc: - logger.error('failed to run %s %s from %s: %s (%s)' - % (transfer_dict['protocol'], transfer_dict['action'], - transfer_dict['fqdn'], exc, blind_pw(transfer_dict))) - transfer_error(configuration, client_id, - 'failed to run %s %s from %s: %s' % - (transfer_dict['protocol'], transfer_dict['action'], - transfer_dict['fqdn'], exc)) + logger.error( + "failed to run %s %s from %s: %s (%s)" + % ( + transfer_dict["protocol"], + transfer_dict["action"], + transfer_dict["fqdn"], + exc, + blind_pw(transfer_dict), + ) + ) + transfer_error( + configuration, + client_id, + "failed to run %s %s from %s: %s" + % ( + transfer_dict["protocol"], + transfer_dict["action"], + transfer_dict["fqdn"], + exc, + ), + ) def manage_transfers(configuration): """Manage all updates of saved user data transfer requests""" - logger.debug('manage transfers') + logger.debug("manage transfers") old_transfers = {} - src_pattern = 
os.path.join(configuration.user_settings, '*', - datatransfers_filename) + src_pattern = os.path.join( + configuration.user_settings, "*", datatransfers_filename + ) for transfers_path in glob.glob(src_pattern): if os.path.getmtime(transfers_path) < last_update: # logger.debug('skip transfer update for unchanged path: %s' % \ # transfers_path) continue - logger.debug('handling update of transfers file: %s' % transfers_path) + logger.debug("handling update of transfers file: %s" % transfers_path) abs_client_dir = os.path.dirname(transfers_path) client_dir = os.path.basename(abs_client_dir) - logger.debug('extracted client dir: %s' % client_dir) + logger.debug("extracted client dir: %s" % client_dir) client_id = client_dir_id(client_dir) - logger.debug('loading transfers for: %s' % client_id) - (load_status, transfers) = load_data_transfers(configuration, - client_id) + logger.debug("loading transfers for: %s" % client_id) + load_status, transfers = load_data_transfers(configuration, client_id) if not load_status: - logger.error('could not load transfer for path: %s' % - transfers_path) + logger.error( + "could not load transfer for path: %s" % transfers_path + ) continue old_transfers[client_id] = all_transfers.get(client_id, {}) all_transfers[client_id] = transfers - for (client_id, transfers) in all_transfers.items(): - for (transfer_id, transfer_dict) in transfers.items(): + for client_id, transfers in all_transfers.items(): + for transfer_id, transfer_dict in transfers.items(): # logger.debug('inspecting transfer:\n%s' % blind_pw(transfer_dict)) - transfer_status = transfer_dict['status'] + transfer_status = transfer_dict["status"] if transfer_status in ("DONE", "FAILED", "PAUSED"): # logger.debug('skip %(status)s transfer %(transfer_id)s' % \ # transfer_dict) continue - if transfer_status in ("ACTIVE", ): - if get_worker_transfer(configuration, all_workers, client_id, - transfer_id): - logger.debug('wait for transfer %(transfer_id)s' % - transfer_dict) + if transfer_status in ("ACTIVE",): + if get_worker_transfer( + configuration, all_workers, client_id, transfer_id + ): + logger.debug( + "wait for transfer %(transfer_id)s" % transfer_dict + ) continue else: - logger.info('restart transfer %(transfer_id)s' % - transfer_dict) - logger.info('handle %(status)s transfer %(transfer_id)s' % - transfer_dict) + logger.info( + "restart transfer %(transfer_id)s" % transfer_dict + ) + logger.info( + "handle %(status)s transfer %(transfer_id)s" % transfer_dict + ) handle_transfer(configuration, client_id, transfer_dict) -if __name__ == '__main__': +if __name__ == "__main__": # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) log_level = configuration.loglevel - if sys.argv[1:] and sys.argv[1] in ['debug', 'info', 'warning', 'error']: + if sys.argv[1:] and sys.argv[1] in ["debug", "info", "warning", "error"]: log_level = sys.argv[1] # Use separate logger - logger = daemon_logger('transfers', configuration.user_transfers_log, - log_level) + logger = daemon_logger( + "transfers", configuration.user_transfers_log, log_level + ) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates @@ -885,7 +1189,7 @@ def manage_transfers(configuration): print(err_msg) sys.exit(1) - print('''This is the MiG data transfer handler daemon which runs requested + print("""This is the MiG data transfer handler daemon which runs requested data transfers in the background on behalf of the users. 
It monitors the saved data transfer files for changes and launches external client processes to take care of the transfers, writing status and output to a transfer output directory @@ -893,11 +1197,11 @@ def manage_transfers(configuration): Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -''') +""") - print('Starting Data Transfer handler daemon - Ctrl-C to quit') + print("Starting Data Transfer handler daemon - Ctrl-C to quit") - logger.info('Starting data transfer handler daemon') + logger.info("Starting data transfer handler daemon") # IMPORTANT: If SIGINT reaches multiprocessing it kills manager dict # proxies and makes sub_pid_map access fail. Register a signal handler @@ -917,43 +1221,54 @@ try: manage_transfers(configuration) - for (client_id, transfer_id, worker) in \ - all_worker_transfers(configuration, all_workers): + for client_id, transfer_id, worker in all_worker_transfers( + configuration, all_workers + ): if not worker: continue - logger.debug('Checking if %s %s with pid %d is finished' % - (client_id, transfer_id, worker.pid)) + logger.debug( + "Checking if %s %s with pid %d is finished" + % (client_id, transfer_id, worker.pid) + ) worker.join(1) if worker.is_alive(): - logger.debug('Worker for %s %s running with pid %d' % - (client_id, transfer_id, worker.pid)) + logger.debug( + "Worker for %s %s running with pid %d" + % (client_id, transfer_id, worker.pid) + ) else: - logger.info('Removing finished %s %s with pid %d' % - (client_id, transfer_id, worker.pid)) + logger.info( + "Removing finished %s %s with pid %d" + % (client_id, transfer_id, worker.pid) + ) clean_transfer(configuration, client_id, transfer_id) # Throttle down time.sleep(30) except Exception as exc: - print('Caught unexpected exception: %s' % exc) + print("Caught unexpected exception: %s" % exc) time.sleep(10) - print('Cleaning up active transfers') - logger.info('Cleaning up workers to prepare for exit') - for (client_id, transfer_id, worker) in \ - all_worker_transfers(configuration, all_workers): + print("Cleaning up active transfers") + logger.info("Cleaning up workers to prepare for exit") + for client_id, transfer_id, worker in all_worker_transfers( + configuration, all_workers + ): if not worker or not worker.is_alive(): continue # Terminate worker first to stop further handling, then kill any # orphaned subprocesses associated with it for clean resume later - logger.info('Terminating %s %s worker with pid %d' % - (client_id, transfer_id, worker.pid)) + logger.info( + "Terminating %s %s worker with pid %d" + % (client_id, transfer_id, worker.pid) + ) worker.terminate() - logger.info('Terminating any %s %s child processes' % (client_id, - transfer_id)) + logger.info( + "Terminating any %s %s child processes" % (client_id, transfer_id) + ) clean_transfer(configuration, client_id, transfer_id, force=True) - print('Data transfer handler daemon shutting down') - logger.info('Stop data transfer handler daemon') + print("Data transfer handler daemon shutting down") + logger.info("Stop data transfer handler daemon") sys.exit(0) diff --git a/mig/server/importdoi.py b/mig/server/importdoi.py index 819931465..596a95fa1 100755 --- a/mig/server/importdoi.py +++ b/mig/server/importdoi.py @@ -26,44 +26,47 @@ # """Import any missing DOIs from provided URI - useful from cron job""" -from __future__ import print_function -from __future__ import absolute_import + +from __future__ import absolute_import, print_function 
import json import os -import requests import sys +import requests + from mig.shared.conf import get_configuration_object from mig.shared.defaults import public_archive_doi def __datacite_req(format, query): """Low-level helper to make a request for data from DataCite""" - url = os.path.join('https://api.datacite.org', format, query) - #print "DEBUG: query datacite REST service on %s" % url + url = os.path.join("https://api.datacite.org", format, query) + # print "DEBUG: query datacite REST service on %s" % url response = requests.get(url) if response.status_code != 200: - raise Exception("unexpected response for %s : %s : %s" % - (url, response.status_code, response.text)) - #print "DEBUG: response\n%s" % response.text + raise Exception( + "unexpected response for %s : %s : %s" + % (url, response.status_code, response.text) + ) + # print "DEBUG: response\n%s" % response.text parsed = json.loads(response.text) return parsed def datacite_query(query): """Make a query against the DataCite REST interface""" - return __datacite_req('works', query) + return __datacite_req("works", query) def datacite_full(doi): """Request full DataCite json content for given DOI value. This is the data corresponding with the Download DataCite JSON entry on the DOI search. """ - return __datacite_req('dois/application/vnd.datacite.datacite+json', doi) + return __datacite_req("dois/application/vnd.datacite.datacite+json", doi) -if __name__ == '__main__': +if __name__ == "__main__": if not sys.argv[1:]: print("""USAGE: importdoi.py VALUE @@ -86,16 +89,16 @@ def datacite_full(doi): configuration = get_configuration_object() target = sys.argv[1] dump = True - query = '' - direct = '' + query = "" + direct = "" verbose = False - if target.startswith('query='): + if target.startswith("query="): dump = False - query += '?' + target - elif target.find('/') != -1: + query += "?" + target + elif target.find("/") != -1: direct += target else: - query += '?query=%s' % target + query += "?query=%s" % target if direct: try: @@ -116,14 +119,14 @@ def datacite_full(doi): except Exception as exc: print("ERROR in DataCite request: %s" % exc) sys.exit(2) - #print "DEBUG: parsed datacite response with %d fields" % len(parsed) + # print "DEBUG: parsed datacite response with %d fields" % len(parsed) # parsed is a dictionary with a data entry holding a list of summary # result dicts. The other entry is meta. 
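# NOTE: purely illustrative of the shape the loop below assumes, not
# verbatim DataCite output:
#
#     {"data": [{"id": "...", "attributes": {"doi": "10.NNNN/..."}}, ...],
#      "meta": {...}}
#
# so the plain DOIs for the follow-up full lookups can be pulled out with
# something like:
#
#     plain_dois = [entry.get("attributes", {}).get("doi")
#                   for entry in parsed.get("data", [])]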
parsed_index = parsed.get("data", []) - #print "DEBUG: repeat full lookup for individual sparse entries" + # print "DEBUG: repeat full lookup for individual sparse entries" parsed_data = [] for entry in parsed_index: - attributes = entry.get('attributes', {}) + attributes = entry.get("attributes", {}) plain_doi = attributes.get("doi", None) if plain_doi is None: print("WARNING skip full lookup of malformed entry: %s" % entry) @@ -142,17 +145,20 @@ def datacite_full(doi): if not isinstance(entry, dict): print("WARNING skip malformed entry: %s" % entry) continue - #print "DEBUG: handle entry: %s" % entry + # print "DEBUG: handle entry: %s" % entry doi_url = entry.get("id", None) doi = entry.get("doi", None) - archive_url = entry.get('url', '') + archive_url = entry.get("url", "") archive_id = os.path.basename(os.path.dirname(archive_url)) if not archive_id or not doi_url: - print("WARNING DOI or archive ID missing from %s (%s %s)" % \ - (entry, archive_id, doi_url)) + print( + "WARNING DOI or archive ID missing from %s (%s %s)" + % (entry, archive_id, doi_url) + ) continue - archive_root = os.path.join(configuration.wwwpublic, 'archives', - archive_id) + archive_root = os.path.join( + configuration.wwwpublic, "archives", archive_id + ) if not os.path.isdir(archive_root): print("ERROR No archive %s for DOI %s data" % (archive_root, doi)) continue @@ -165,7 +171,7 @@ def datacite_full(doi): new += 1 if dump: print("Save DOI %s for archive %s" % (doi, archive_id)) - doi_fd = open(doi_path, 'w') + doi_fd = open(doi_path, "w") json.dump(entry, doi_fd) doi_fd.close() imported += 1 @@ -174,6 +180,8 @@ def datacite_full(doi): if verbose: print("\t%s" % entry) - print("Found %d existing - and imported %d of %d new DOI entries" % \ - (existing, imported, new)) + print( + "Found %d existing - and imported %d of %d new DOI entries" + % (existing, imported, new) + ) sys.exit(0) diff --git a/mig/server/importusers.py b/mig/server/importusers.py index 9bf2411d8..927e8f066 100755 --- a/mig/server/importusers.py +++ b/mig/server/importusers.py @@ -27,8 +27,7 @@ """Import any missing users from provided URI""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import os @@ -39,23 +38,33 @@ from mig.shared import returnvalues from mig.shared.accountstate import default_account_expire -from mig.shared.base import fill_user, distinguished_name_to_user, \ - force_native_str +from mig.shared.base import ( + distinguished_name_to_user, + fill_user, + force_native_str, +) from mig.shared.conf import get_configuration_object from mig.shared.defaults import csrf_field, keyword_auto, valid_auth_types from mig.shared.functionality.sendrequestaction import main from mig.shared.handlers import get_csrf_limit, make_csrf_token from mig.shared.output import format_output -from mig.shared.pwcrypto import generate_random_password, unscramble_password, \ - scramble_password +from mig.shared.pwcrypto import ( + generate_random_password, + scramble_password, + unscramble_password, +) from mig.shared.safeinput import valid_password_chars from mig.shared.url import urlopen -from mig.shared.useradm import init_user_adm, default_search, create_user, \ - search_users +from mig.shared.useradm import ( + create_user, + default_search, + init_user_adm, + search_users, +) from mig.shared.vgridaccess import refresh_user_map -def usage(name='importusers.py'): +def usage(name="importusers.py"): """Usage help""" print("""Import users from an 
external plain text or XML source URI. @@ -78,7 +87,7 @@ def usage(name='importusers.py'): -p PEER_PATTERN Verify in Peers of existing account matching PEER_PATTERN -P PASSWORD Optional PASSWORD to set for user (AUTO to generate one) -v Verbose output -""" % {'name': name}) +""" % {"name": name}) def dump_contents(url, key_path=None, cert_path=None): @@ -106,17 +115,17 @@ def parse_contents(user_data): """ users = [] - for user_creds in re.findall('/[a-zA-Z]+=[^<\n]+', user_data): - #print "DEBUG: handling user %s" % user_creds + for user_creds in re.findall("/[a-zA-Z]+=[^<\n]+", user_data): + # print "DEBUG: handling user %s" % user_creds user_dict = distinguished_name_to_user(user_creds.strip()) users.append(user_dict) return users -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None - auth_type = 'custom' + auth_type = "custom" key_path = None cert_path = None expire = None @@ -125,50 +134,50 @@ def parse_contents(user_data): verbose = False vgrids = [] override_fields = {} - opt_args = 'a:C:c:d:e:fhK:m:p:P:v' + opt_args = "a:C:c:d:e:fhK:m:p:P:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-a': + for opt, val in opts: + if opt == "-a": auth_type = val - elif opt == '-c': + elif opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-e': + elif opt == "-e": expire = int(val) - override_fields['expire'] = expire - override_fields['status'] = 'temporal' - elif opt == '-f': + override_fields["expire"] = expire + override_fields["status"] = "temporal" + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-C': + elif opt == "-C": cert_path = val - elif opt == '-K': + elif opt == "-K": key_path = val - elif opt == '-m': + elif opt == "-m": vgrids.append(val) - elif opt == '-p': + elif opt == "-p": peer_pattern = val - override_fields['peer_pattern'] = peer_pattern - override_fields['status'] = 'temporal' - elif opt == '-P': + override_fields["peer_pattern"] = peer_pattern + override_fields["status"] = "temporal" + elif opt == "-P": password = val - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) sys.exit(1) if not args: - print('Must provide one or more URIs to import from') + print("Must provide one or more URIs to import from") usage() sys.exit(1) @@ -176,27 +185,32 @@ def parse_contents(user_data): for url in args: url_dump = dump_contents(url, key_path, cert_path) users += parse_contents(url_dump) - #print "DEBUG: raw users to import: %s" % users + # print "DEBUG: raw users to import: %s" % users if auth_type not in valid_auth_types: - print('Error: invalid account auth type %r requested (allowed: %s)' % - (auth_type, ', '.join(valid_auth_types))) + print( + "Error: invalid account auth type %r requested (allowed: %s)" + % (auth_type, ", ".join(valid_auth_types)) + ) usage() sys.exit(1) new_users = [] for user_dict in users: id_search = default_search() - id_search['distinguished_name'] = user_dict['distinguished_name'] - (configuration, hits) = search_users(id_search, conf_path, db_path, - verbose) + id_search["distinguished_name"] = user_dict["distinguished_name"] + configuration, hits = search_users( + id_search, conf_path, db_path, verbose + ) if hits: if verbose: - print('Not adding existing user: %(distinguished_name)s' % - user_dict) + print( + "Not adding existing user: %(distinguished_name)s" + % user_dict + ) continue new_users.append(user_dict) - #print "DEBUG: new users to import: %s" % new_users + # print "DEBUG: new users to import: %s" % new_users configuration = get_configuration_object() @@ -205,30 +219,31 @@ def parse_contents(user_data): for user_dict in new_users: fill_user(user_dict) - client_id = user_dict['distinguished_name'] - user_dict['comment'] = 'imported from external URI' + client_id = user_dict["distinguished_name"] + user_dict["comment"] = "imported from external URI" if password == keyword_auto: - print('Auto generating password for user: %s' % client_id) - user_dict['password'] = generate_random_password(configuration) + print("Auto generating password for user: %s" % client_id) + user_dict["password"] = generate_random_password(configuration) elif password: - print('Setting provided password for user: %s' % client_id) - user_dict['password'] = password + print("Setting provided password for user: %s" % client_id) + user_dict["password"] = password else: - print('Setting empty password for user: %s' % client_id) - user_dict['password'] = '' + print("Setting empty password for user: %s" % client_id) + user_dict["password"] = "" # Encode password if set but not already encoded - if user_dict['password']: + if user_dict["password"]: if verbose: - print('Scrambling password for user: %s' % client_id) - user_dict['password'] = scramble_password( - configuration.site_password_salt, user_dict['password']) + print("Scrambling password for user: %s" % client_id) + user_dict["password"] = scramble_password( + configuration.site_password_salt, user_dict["password"] + ) # Force expire - user_dict['expire'] = expire + user_dict["expire"] = expire # NOTE: let non-ID command line values override loaded values - for (key, val) in list(override_fields.items()): + for key, val in list(override_fields.items()): user_dict[key] = val try: @@ -236,47 +251,63 @@ def parse_contents(user_data): except Exception as exc: print(exc) continue - print('Created %s in user database and in file system' % client_id) + print("Created %s in user database and in file system" % client_id) # NOTE: force update user_map before calling sendrequestaction! # create_user does NOT necessarily update it due to caching time. 
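# NOTE: a sketch of why the explicit refresh matters (create_user arguments
# abbreviated here for illustration):
#
#     create_user(user_dict, ...)         # user DB and file system updated
#     refresh_user_map(configuration)     # rebuild cached map right away
#     main(client_id, request)            # sendrequestaction now sees the user
#
# without the middle step the membership requests below may be rejected for
# a seemingly unknown client_id until the cache window expires.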
refresh_user_map(configuration) # Needed for CSRF check in safe_handler - form_method = 'post' + form_method = "post" csrf_limit = get_csrf_limit(configuration) - target_op = 'sendrequestaction' - os.environ.update({'SCRIPT_URL': '%s.py' % target_op, - 'REQUEST_METHOD': form_method}) + target_op = "sendrequestaction" + os.environ.update( + {"SCRIPT_URL": "%s.py" % target_op, "REQUEST_METHOD": form_method} + ) for user_dict in new_users: fill_user(user_dict) - client_id = user_dict['distinguished_name'] - csrf_token = make_csrf_token(configuration, form_method, target_op, - client_id, csrf_limit) + client_id = user_dict["distinguished_name"] + csrf_token = make_csrf_token( + configuration, form_method, target_op, client_id, csrf_limit + ) for name in vgrids: - request = {'vgrid_name': [name], 'request_type': ['vgridmember'], - 'request_text': - ['automatic request from importusers script'], - csrf_field: [csrf_token]} - (output_objs, status) = main(client_id, request) + request = { + "vgrid_name": [name], + "request_type": ["vgridmember"], + "request_text": ["automatic request from importusers script"], + csrf_field: [csrf_token], + } + output_objs, status = main(client_id, request) if status == returnvalues.OK: - print('Request for %s membership in %s sent to owners' % - (client_id, name)) + print( + "Request for %s membership in %s sent to owners" + % (client_id, name) + ) else: - print('Request for %s membership in %s with %s failed:' % - (client_id, name, request)) - output_format = 'text' - (ret_code, ret_msg) = status - output = format_output(configuration, target_op, ret_code, - ret_msg, output_objs, output_format) + print( + "Request for %s membership in %s with %s failed:" + % (client_id, name, request) + ) + output_format = "text" + ret_code, ret_msg = status + output = format_output( + configuration, + target_op, + ret_code, + ret_msg, + output_objs, + output_format, + ) # Explicit None means error during output formatting if output is None: - print("ERROR: %s output formatting failed: %s" % - (output_format, output_objs)) - output = 'Error: output could not be correctly delivered!' + print( + "ERROR: %s output formatting failed: %s" + % (output_format, output_objs) + ) + output = "Error: output could not be correctly delivered!" 
else: print(output) - print('%d new users imported' % len(new_users)) + print("%d new users imported" % len(new_users)) diff --git a/mig/server/indexdoi.py b/mig/server/indexdoi.py index 604a889b3..948f2379a 100755 --- a/mig/server/indexdoi.py +++ b/mig/server/indexdoi.py @@ -27,8 +27,7 @@ """Build index page listing all imported site DOIs - useful from cron job""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import datetime import glob @@ -38,8 +37,12 @@ from mig.shared.conf import get_configuration_object from mig.shared.defaults import public_archive_doi, public_doi_index -from mig.shared.htmlgen import get_xgi_html_header, get_xgi_html_footer, \ - themed_styles, themed_scripts +from mig.shared.htmlgen import ( + get_xgi_html_footer, + get_xgi_html_header, + themed_scripts, + themed_styles, +) def extract_imported_doi_dicts(configuration): @@ -49,11 +52,12 @@ def extract_imported_doi_dicts(configuration): """ _logger = configuration.logger all_imported = [] - doi_import_pattern = os.path.join(configuration.wwwpublic, 'archives', - '*', public_archive_doi) + doi_import_pattern = os.path.join( + configuration.wwwpublic, "archives", "*", public_archive_doi + ) for doi_dump_path in glob.glob(doi_import_pattern): try: - json_fd = open(doi_dump_path, 'rb') + json_fd = open(doi_dump_path, "rb") doi_import = json.load(json_fd) json_fd.close() all_imported.append((os.path.getctime(doi_dump_path), doi_import)) @@ -65,8 +69,8 @@ def extract_imported_doi_dicts(configuration): return date_sorted -if __name__ == '__main__': - if '-h' in sys.argv[1:]: +if __name__ == "__main__": + if "-h" in sys.argv[1:]: print("""USAGE: indexdoi.py [OPTIONS] @@ -98,17 +102,22 @@ def extract_imported_doi_dicts(configuration): print("handling entry for %s" % plain_doi) doi_url = entry.get("id", None) - archive_url = entry.get('url', '') + archive_url = entry.get("url", "") archive_id = os.path.basename(os.path.dirname(archive_url)) if not archive_id or not doi_url: - print("WARNING DOI or archive ID missing from %s (%s %s)" % - (entry, archive_id, doi_url)) + print( + "WARNING DOI or archive ID missing from %s (%s %s)" + % (entry, archive_id, doi_url) + ) continue - archive_root = os.path.join(configuration.wwwpublic, 'archives', - archive_id) + archive_root = os.path.join( + configuration.wwwpublic, "archives", archive_id + ) if not os.path.isdir(archive_root): - print("ERROR No archive %s for DOI %s data" % - (archive_root, plain_doi)) + print( + "ERROR No archive %s for DOI %s data" + % (archive_root, plain_doi) + ) continue doi_path = os.path.join(archive_root, public_archive_doi) if os.path.exists(doi_path): @@ -118,11 +127,12 @@ def extract_imported_doi_dicts(configuration): doi_count += 1 if dump: - fill_helpers = {'short_title': configuration.short_title, - 'update_stamp': now, - 'doi_count': doi_count, - } - publish_title = '%(short_title)s DOI Index' % fill_helpers + fill_helpers = { + "short_title": configuration.short_title, + "update_stamp": now, + "doi_count": doi_count, + } + publish_title = "%(short_title)s DOI Index" % fill_helpers # Fake manager themed style setup for tablesorter layout with site style style_entry = themed_styles(configuration, user_settings={}) @@ -132,16 +142,20 @@ def extract_imported_doi_dicts(configuration): # NOTE: use mark_static to insert classic page top logo like on V2 pages # using staticpage class for flexible skinning. Otherwise index has no # branding/skin whatsoever. 
- contents = get_xgi_html_header(configuration, publish_title, '', - style_map=style_entry, - script_map=script_entry, - frame=False, - menu=False, - widgets=False, - userstyle=False, - mark_static=True) - - contents += ''' + contents = get_xgi_html_header( + configuration, + publish_title, + "", + style_map=style_entry, + script_map=script_entry, + frame=False, + menu=False, + widgets=False, + userstyle=False, + mark_static=True, + ) + + contents += """

%(short_title)s DOI Index

@@ -154,36 +168,40 @@ def extract_imported_doi_dicts(configuration):
-''' +""" - doi_lines = '' - for (doi, url) in doi_exports: - doi_lines += ''' + doi_lines = "" + for doi, url in doi_exports: + doi_lines += """

%s

-''' % (url, doi) +""" % (url, doi) contents += doi_lines - contents += ''' + contents += """
%s -''' % get_xgi_html_footer(configuration, widgets=False, mark_static=True) +""" % get_xgi_html_footer(configuration, widgets=False, mark_static=True) try: - index_fd = open(doi_index_path, 'w') + index_fd = open(doi_index_path, "w") index_fd.write(contents % fill_helpers) index_fd.close() - msg = "Published index of %d DOIs in %s" % \ - (doi_count, doi_index_path) + msg = "Published index of %d DOIs in %s" % ( + doi_count, + doi_index_path, + ) _logger.info(msg) if verbose: print(msg) except Exception as exc: msg = "failed to write %s: %s" % (doi_index_path, exc) _logger.error(msg) - print("Error writing index of %d DOIs in %s" % - (doi_count, doi_index_path)) + print( + "Error writing index of %d DOIs in %s" + % (doi_count, doi_index_path) + ) sys.exit(1) sys.exit(0) diff --git a/mig/server/managecloud.py b/mig/server/managecloud.py index 158cd5019..4b32940b6 100755 --- a/mig/server/managecloud.py +++ b/mig/server/managecloud.py @@ -29,21 +29,26 @@ belonging to an expired user account. """ -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import pickle import sys +from mig.shared.cloud import ( + cloud_load_instance, + lookup_user_service_value, + restart_cloud_instance, + start_cloud_instance, + status_all_cloud_instances, + status_of_cloud_instance, + stop_cloud_instance, +) from mig.shared.defaults import keyword_all -from mig.shared.useradm import init_user_adm, search_users, default_search -from mig.shared.cloud import lookup_user_service_value, cloud_load_instance, \ - status_all_cloud_instances, start_cloud_instance, stop_cloud_instance, \ - restart_cloud_instance, status_of_cloud_instance +from mig.shared.useradm import default_search, init_user_adm, search_users -def usage(name='managecloud.py'): +def usage(name="managecloud.py"): """Usage help""" print("""Manage cloud instance for users. @@ -57,62 +62,66 @@ def usage(name='managecloud.py'): -h Show this help -I CERT_DN Limit to instances for user with ID (distinguished name) -v Verbose output -""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None force = False verbose = False user_file = None search_filter = default_search() - opt_args = 'c:d:hfI:v' + opt_args = "c:d:hfI:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-I': - search_filter['distinguished_name'] = val - elif opt == '-v': + elif opt == "-I": + search_filter["distinguished_name"] = val + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) usage() sys.exit(0) if not args[1:]: - print('Error: at least two non-option arguments are required') + print("Error: at least two non-option arguments are required") usage() sys.exit(1) action = args[0] instance_list = args[1:] - action_map = {'start': start_cloud_instance, 'stop': stop_cloud_instance, - 'restart': restart_cloud_instance, - 'status': status_of_cloud_instance} + action_map = { + "start": start_cloud_instance, + "stop": stop_cloud_instance, + "restart": restart_cloud_instance, + "status": status_of_cloud_instance, + } if not action in action_map: - print('Error: action must be one of %s' % action_map.keys()) + print("Error: action must be one of %s" % action_map.keys()) usage() sys.exit(1) - uid = 'unknown' - (configuration, hits) = search_users(search_filter, conf_path, db_path, - verbose) + uid = "unknown" + configuration, hits = search_users( + search_filter, conf_path, db_path, verbose + ) services = configuration.cloud_services if not hits: print("No matching users in user DB") @@ -121,38 +130,57 @@ def usage(name='managecloud.py'): # Reuse conf and hits as a sparse user DB for speed conf_path, db_path = configuration, dict(hits) print("Cloud action: %s" % action) - for (uid, user_dict) in hits: + for uid, user_dict in hits: if verbose: print("Checking %s" % uid) for service in services: - cloud_id = service['service_name'] - cloud_title = service['service_title'] + cloud_id = service["service_name"] + cloud_title = service["service_title"] cloud_flavor = service.get("service_provider_flavor", "openstack") max_instances = lookup_user_service_value( - configuration, uid, service, 'service_max_user_instances') + configuration, uid, service, "service_max_user_instances" + ) max_user_instances = int(max_instances) - print('%s cloud instances allowed for %s: %d' % - (cloud_title, uid, max_user_instances)) + print( + "%s cloud instances allowed for %s: %d" + % (cloud_title, uid, max_user_instances) + ) # Load all user instances and show status - saved_instances = cloud_load_instance(configuration, uid, - cloud_id, keyword_all) - instance_fields = ['public_fqdn', 'status'] + saved_instances = cloud_load_instance( + configuration, uid, cloud_id, keyword_all + ) + instance_fields = ["public_fqdn", "status"] status_map = status_all_cloud_instances( - configuration, uid, cloud_id, cloud_flavor, - list(saved_instances), instance_fields) + configuration, + uid, + cloud_id, + cloud_flavor, + list(saved_instances), + instance_fields, + ) action_helper = action_map[action] - for (instance_id, instance_dict) in saved_instances.items(): + for instance_id, instance_dict in saved_instances.items(): if not instance_id in instance_list: continue - instance_label = instance_dict.get('INSTANCE_LABEL', - instance_id) + instance_label = instance_dict.get( + "INSTANCE_LABEL", instance_id + ) # print('%s cloud instance %s (%s) for %s at %s status: %s' % # (cloud_title, instance_label, instance_id, uid, # status_map[instance_id]['public_fqdn'], # status_map[instance_id]['status'])) - result = action_helper(configuration, uid, cloud_id, - cloud_flavor, instance_id) - print('%s cloud instance %s (%s) for %s at %s applied %s: %s' % - (cloud_title, instance_label, instance_id, uid, - status_map[instance_id]['public_fqdn'], action, - result)) + result = action_helper( + configuration, uid, cloud_id, cloud_flavor, instance_id + ) + print( + "%s cloud instance %s (%s) for %s at %s applied %s: %s" + % ( + cloud_title, + instance_label, + instance_id, + uid, + status_map[instance_id]["public_fqdn"], 
+                        action,
+                        result,
+                    )
+                )
diff --git a/mig/server/notifyexpire.py b/mig/server/notifyexpire.py
index ce620d3f2..263b8dbce 100755
--- a/mig/server/notifyexpire.py
+++ b/mig/server/notifyexpire.py
@@ -37,8 +37,7 @@
 configured additional messaging protocols they can also be used.
 """

-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function

 import datetime
 import getopt
@@ -46,18 +45,34 @@
 import time

 from mig.shared.accountstate import check_account_expire
-from mig.shared.cloud import check_cloud_available, cloud_access_allowed, \
-    status_all_cloud_instances, cloud_load_instance
-from mig.shared.defaults import keyword_auto, gdp_distinguished_field, \
-    keyword_all
+from mig.shared.cloud import (
+    check_cloud_available,
+    cloud_access_allowed,
+    cloud_load_instance,
+    status_all_cloud_instances,
+)
+from mig.shared.defaults import (
+    gdp_distinguished_field,
+    keyword_all,
+    keyword_auto,
+)
 from mig.shared.notification import notify_user
-from mig.shared.settings import load_ssh, load_ftps, load_davs, load_seafile, \
-    load_cloud
-from mig.shared.useradm import init_user_adm, search_users, default_search, \
-    user_account_notify
-
-
-def usage(name='notifyexpire.py'):
+from mig.shared.settings import (
+    load_cloud,
+    load_davs,
+    load_ftps,
+    load_seafile,
+    load_ssh,
+)
+from mig.shared.useradm import (
+    default_search,
+    init_user_adm,
+    search_users,
+    user_account_notify,
+)
+
+
+def usage(name="notifyexpire.py"):
     """Usage help"""
     print("""Check internal OpenID account expire for user(s) from user
@@ -80,11 +95,11 @@ def usage(name='notifyexpire.py'):
 One or more destinations may be set by combining multiple -e, -s and
 -a options.
-""" % {'name': name})
+""" % {"name": name})

-if '__main__' == __name__:
-    (args, app_dir, db_path) = init_user_adm()
+if "__main__" == __name__:
+    args, app_dir, db_path = init_user_adm()
     conf_path = None
     verbose = False
     admin_copy = False
@@ -92,189 +107,205 @@ def usage(name='notifyexpire.py'):
     user_id = None
     search_filter = default_search()
     # Default to all users with expire range between now and in 30 days
-    search_filter['distinguished_name'] = '*'
-    search_filter['expire_after'] = int(time.time())
-    search_filter['expire_before'] = int(time.time() + 30 * 24 * 3600)
+    search_filter["distinguished_name"] = "*"
+    search_filter["expire_after"] = int(time.time())
+    search_filter["expire_before"] = int(time.time() + 30 * 24 * 3600)
     # Default to only internal openid warnings
-    services = ['migoid']
+    services = ["migoid"]
     now = int(time.time())
     exit_code = 0
-    opt_args = 'aA:B:c:Cd:e:hI:s:S:v'
+    opt_args = "aA:B:c:Cd:e:hI:s:S:v"
     try:
-        (opts, args) = getopt.getopt(args, opt_args)
+        opts, args = getopt.getopt(args, opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)

-    for (opt, val) in opts:
-        if opt == '-a':
-            raw_targets['email'] = raw_targets.get('email', [])
-            raw_targets['email'].append(keyword_auto)
-        elif opt == '-A':
+    for opt, val in opts:
+        if opt == "-a":
+            raw_targets["email"] = raw_targets.get("email", [])
+            raw_targets["email"].append(keyword_auto)
+        elif opt == "-A":
             after = now
-            if val.startswith('+'):
+            if val.startswith("+"):
                 after += int(val[1:])
-            elif val.startswith('-'):
+            elif val.startswith("-"):
                 after -= int(val[1:])
             else:
                 after = int(val)
-            search_filter['expire_after'] = after
-        elif opt == '-B':
+            search_filter["expire_after"] = after
+        elif opt == "-B":
             before = now
-            if val.startswith('+'):
+            if val.startswith("+"):
                 before += int(val[1:])
-            elif val.startswith('-'):
+            elif val.startswith("-"):
                 before -= int(val[1:])
             else:
                 before = int(val)
-            search_filter['expire_before'] = before
-        elif opt == '-c':
+            search_filter["expire_before"] = before
+        elif opt == "-c":
             conf_path = val
-        elif opt == '-C':
+        elif opt == "-C":
             admin_copy = True
-        elif opt == '-d':
+        elif opt == "-d":
             db_path = val
-        elif opt == '-e':
-            raw_targets['email'] = raw_targets.get('email', [])
-            raw_targets['email'].append(val)
-        elif opt == '-h':
+        elif opt == "-e":
+            raw_targets["email"] = raw_targets.get("email", [])
+            raw_targets["email"].append(val)
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-I':
-            search_filter['distinguished_name'] = val
-        elif opt == '-s':
+        elif opt == "-I":
+            search_filter["distinguished_name"] = val
+        elif opt == "-s":
             val = val.lower()
             raw_targets[val] = raw_targets.get(val, [])
-            raw_targets[val].append('SETTINGS')
-        elif opt == '-S':
+            raw_targets[val].append("SETTINGS")
+        elif opt == "-S":
             # Force unique list of non-empty entries
             services = list(dict([(i, 0) for i in val.split() if i.strip()]))
-        elif opt == '-v':
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)
             usage()
             sys.exit(0)

     if args:
-        print('Error: Non-option arguments are not supported - missing quotes?')
+        print("Error: Non-option arguments are not supported - missing quotes?")
         usage()
         sys.exit(1)

-    (configuration, hits) = search_users(search_filter, conf_path, db_path,
-                                         verbose)
+    configuration, hits = search_users(
+        search_filter, conf_path, db_path, verbose
+    )
     logger = configuration.logger
     gdp_prefix = "%s=" % gdp_distinguished_field
     # NOTE: we already filtered expired accounts here
-    search_dn = search_filter['distinguished_name']
-    before = datetime.datetime.fromtimestamp(search_filter['expire_before'])
-    after = datetime.datetime.fromtimestamp(search_filter['expire_after'])
+    search_dn = search_filter["distinguished_name"]
+    before = datetime.datetime.fromtimestamp(search_filter["expire_before"])
+    after = datetime.datetime.fromtimestamp(search_filter["expire_after"])
     if verbose:
         if hits:
-            print("Check %d expire(s) between %s and %s for user ID '%s'" %
-                  (len(hits), after, before, search_dn))
+            print(
+                "Check %d expire(s) between %s and %s for user ID '%s'"
+                % (len(hits), after, before, search_dn)
+            )
         else:
-            print("No expires between %s and %s for user ID '%s'" %
-                  (after, before, search_dn))
+            print(
+                "No expires between %s and %s for user ID '%s'"
+                % (after, before, search_dn)
+            )

-    if 'cloud' in services and not configuration.site_enable_cloud:
+    if "cloud" in services and not configuration.site_enable_cloud:
         print("WARNING: removing cloud which is not enabled on this site!")
-        services.remove('cloud')
-    elif 'seafile' in services and not configuration.site_enable_seafile:
+        services.remove("cloud")
+    elif "seafile" in services and not configuration.site_enable_seafile:
         print("WARNING: removing seafile which is not enabled on this site!")
-        services.remove('seafile')
-    elif 'migoid' in services and not configuration.site_enable_openid:
+        services.remove("seafile")
+    elif "migoid" in services and not configuration.site_enable_openid:
         print("WARNING: removing migoid which is not enabled on this site!")
-        services.remove('migoid')
+        services.remove("migoid")

-    for (user_id, user_dict) in hits:
+    for user_id, user_dict in hits:
         affected = []
         if verbose:
-            print('Check for %s' % user_id)
print("Check for %s" % user_id) - if configuration.site_enable_gdp and \ - user_id.split('/')[-1].startswith(gdp_prefix): + if configuration.site_enable_gdp and user_id.split("/")[-1].startswith( + gdp_prefix + ): if verbose: print("Skip GDP project account: %s" % user_id) continue # Don't warn about already disabled or suspended accounts - account_state = user_dict.get('status', 'active') - if not account_state in ('active', 'temporal'): + account_state = user_dict.get("status", "active") + if not account_state in ("active", "temporal"): if verbose: - print('Skip handling of already %s user %s' % (account_state, - user_id)) + print( + "Skip handling of already %s user %s" + % (account_state, user_id) + ) continue - known_auth = user_dict.get('auth', []) + known_auth = user_dict.get("auth", []) if not known_auth: - if user_dict.get('openid_names', []): - if user_dict.get('password_hash', ''): + if user_dict.get("openid_names", []): + if user_dict.get("password_hash", ""): known_auth.append("migoid") else: known_auth.append("extoid") - elif user_dict.get('password', ''): + elif user_dict.get("password", ""): known_auth.append("migcert") else: if verbose: - print('Skip handling of user %s without auth info' % - user_id) + print( + "Skip handling of user %s without auth info" % user_id + ) continue - elif "migoid" in known_auth and not user_dict.get('password_hash', ''): + elif "migoid" in known_auth and not user_dict.get("password_hash", ""): # Users switching between internal and external auth may end up here if verbose: - print('Skip migoid expire warn for user %s without password' \ - % user_id) + print( + "Skip migoid expire warn for user %s without password" + % user_id + ) known_auth = [i for i in known_auth if i != "migoid"] continue auth_services = [i for i in known_auth if i in services] if auth_services: - (pending_expire, account_expire, _) = check_account_expire( - configuration, user_id) - if account_expire > search_filter['expire_after'] and \ - account_expire < search_filter['expire_before']: + pending_expire, account_expire, _ = check_account_expire( + configuration, user_id + ) + if ( + account_expire > search_filter["expire_after"] + and account_expire < search_filter["expire_before"] + ): affected += auth_services - if 'ssh' in services or 'sftp' in services: + if "ssh" in services or "sftp" in services: svc_dict = load_ssh(user_id, configuration) if not svc_dict: svc_dict = {} - svc_creds = svc_dict.get('authpassword', '') or \ - svc_dict.get('authkeys', '') + svc_creds = svc_dict.get("authpassword", "") or svc_dict.get( + "authkeys", "" + ) if svc_creds: - affected.append('sftp') + affected.append("sftp") - if 'ftps' in services: + if "ftps" in services: svc_dict = load_ftps(user_id, configuration) if not svc_dict: svc_dict = {} - svc_creds = svc_dict.get('authpassword', '') + svc_creds = svc_dict.get("authpassword", "") if svc_creds: - affected.append('ftps') + affected.append("ftps") - if 'davs' in services or 'webdavs' in services: + if "davs" in services or "webdavs" in services: svc_dict = load_davs(user_id, configuration) if not svc_dict: svc_dict = {} - svc_creds = svc_dict.get('authpassword', '') + svc_creds = svc_dict.get("authpassword", "") if svc_creds: - affected.append('webdavs') + affected.append("webdavs") - if 'seafile' in services: + if "seafile" in services: svc_dict = load_seafile(user_id, configuration) if not svc_dict: svc_dict = {} - svc_creds = svc_dict.get('authpassword', '') + svc_creds = svc_dict.get("authpassword", "") if svc_creds: - 
-                affected.append('seafile')
+                affected.append("seafile")

-        if 'cloud' in services:
+        if "cloud" in services:
             if not cloud_access_allowed(configuration, user_dict):
                 if verbose:
-                    print('Skip handling of cloud without access for %s' %
-                          user_id)
+                    print(
+                        "Skip handling of cloud without access for %s" % user_id
+                    )
                 cloud_services = []
             else:
                 cloud_services = configuration.cloud_services
@@ -283,53 +314,67 @@ def usage(name='notifyexpire.py'):
             svc_dict = load_cloud(user_id, configuration)
             if not svc_dict:
                 svc_dict = {}
-            svc_creds = svc_dict.get('authkeys', '')
+            svc_creds = svc_dict.get("authkeys", "")
             # TODO: only count cloud effected if active instances?
             # We most likely need to at least remove cloud keys from jump host
             if svc_creds:
-                affected.append('cloud')
+                affected.append("cloud")

             for cloud_svc in cloud_services:
-                cloud_id = cloud_svc['service_name']
-                cloud_title = cloud_svc['service_title']
-                cloud_flavor = cloud_svc.get("service_provider_flavor",
-                                             "openstack")
-
-                if not check_cloud_available(configuration, user_id, cloud_id,
-                                             cloud_flavor):
+                cloud_id = cloud_svc["service_name"]
+                cloud_title = cloud_svc["service_title"]
+                cloud_flavor = cloud_svc.get(
+                    "service_provider_flavor", "openstack"
+                )
+
+                if not check_cloud_available(
+                    configuration, user_id, cloud_id, cloud_flavor
+                ):
                     if verbose:
-                        print('Skip handling of unavailable cloud %s for %s' %
-                              (cloud_title, user_id))
+                        print(
+                            "Skip handling of unavailable cloud %s for %s"
+                            % (cloud_title, user_id)
+                        )
                     continue

                 # Check instances created and running
-                saved_instances = cloud_load_instance(configuration, user_id,
-                                                      cloud_id, keyword_all)
-
-                instance_fields = ['public_fqdn', 'status']
-                status_map = status_all_cloud_instances(configuration, user_id,
-                                                        cloud_id, cloud_flavor,
-                                                        list(saved_instances),
-                                                        instance_fields)
-                for (instance_id, instance_dict) in saved_instances.items():
-                    instance_label = instance_dict.get('INSTANCE_LABEL',
-                                                       instance_id)
-                    instance_status = status_map[instance_id].get('status',
-                                                                  "UNKNOWN")
-                    if instance_status in ['stopped']:
+                saved_instances = cloud_load_instance(
+                    configuration, user_id, cloud_id, keyword_all
+                )
+
+                instance_fields = ["public_fqdn", "status"]
+                status_map = status_all_cloud_instances(
+                    configuration,
+                    user_id,
+                    cloud_id,
+                    cloud_flavor,
+                    list(saved_instances),
+                    instance_fields,
+                )
+                for instance_id, instance_dict in saved_instances.items():
+                    instance_label = instance_dict.get(
+                        "INSTANCE_LABEL", instance_id
+                    )
+                    instance_status = status_map[instance_id].get(
+                        "status", "UNKNOWN"
+                    )
+                    if instance_status in ["stopped"]:
                         if verbose:
-                            print('Skip stopped %s instance %s for %s' %
-                                  (cloud_title, instance_id, user_id))
+                            print(
+                                "Skip stopped %s instance %s for %s"
+                                % (cloud_title, instance_id, user_id)
+                            )
                         continue
                     else:
-                        if not 'cloud' in affected:
-                            affected.append('cloud')
+                        if not "cloud" in affected:
+                            affected.append("cloud")

-        (_, username, full_name, addresses, errors) = user_account_notify(
-            user_id, raw_targets, conf_path, db_path, verbose, admin_copy)
+        _, username, full_name, addresses, errors = user_account_notify(
+            user_id, raw_targets, conf_path, db_path, verbose, admin_copy
+        )
         if errors:
             print("Address lookup errors for %s :" % user_id)
-            print('\n'.join(errors))
+            print("\n".join(errors))
             exit_code += 1
             continue
         if not username:
@@ -341,20 +386,29 @@ def usage(name='notifyexpire.py'):
             print("No affected services for %s" % user_id)
             continue

-        expire = datetime.datetime.fromtimestamp(user_dict['expire'])
-        print("Account %s expires on %s - affected services: %s" %
-              (user_id, expire, ', '.join(affected)))
-        notify_dict = {'JOB_ID': 'NOJOBID', 'USER_CERT': user_id, 'NOTIFY': []}
-        for (proto, address_list) in addresses.items():
+        expire = datetime.datetime.fromtimestamp(user_dict["expire"])
+        print(
+            "Account %s expires on %s - affected services: %s"
+            % (user_id, expire, ", ".join(affected))
+        )
+        notify_dict = {"JOB_ID": "NOJOBID", "USER_CERT": user_id, "NOTIFY": []}
+        for proto, address_list in addresses.items():
             for address in address_list:
-                notify_dict['NOTIFY'].append('%s: %s' % (proto, address))
+                notify_dict["NOTIFY"].append("%s: %s" % (proto, address))
         # Don't actually send unless requested
         if not raw_targets and not admin_copy:
             continue
-        print("Send account expire warning for '%s' to:\n%s"
-              % (user_id, '\n'.join(notify_dict['NOTIFY'])))
-        notify_user(notify_dict, [user_id, username, full_name, user_dict,
-                                  affected], 'ACCOUNTEXPIRE', logger, '',
-                    configuration)
+        print(
+            "Send account expire warning for '%s' to:\n%s"
+            % (user_id, "\n".join(notify_dict["NOTIFY"]))
+        )
+        notify_user(
+            notify_dict,
+            [user_id, username, full_name, user_dict, affected],
+            "ACCOUNTEXPIRE",
+            logger,
+            "",
+            configuration,
+        )

     sys.exit(exit_code)
diff --git a/mig/server/notifymigoid.py b/mig/server/notifymigoid.py
index 6dd5974aa..7b1bbd47a 100755
--- a/mig/server/notifymigoid.py
+++ b/mig/server/notifymigoid.py
@@ -34,8 +34,7 @@
 configured additional messaging protocols they can also be used.
 """

-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function

 import getopt
 import sys
@@ -45,7 +44,7 @@
 from mig.shared.useradm import init_user_adm, user_account_notify

-def usage(name='notifymigoid.py'):
+def usage(name="notifymigoid.py"):
     """Usage help"""
     print("""Send internal OpenID account create/renew intro to user from user
@@ -65,55 +64,55 @@ def usage(name='notifymigoid.py'):
 One or more destinations may be set by combining multiple -e, -s and
 -a options.
-""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None verbose = False admin_copy = False raw_targets = {} user_id = None - opt_args = 'ac:Cd:e:hI:s:v' + opt_args = "ac:Cd:e:hI:s:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-a': - raw_targets['email'] = raw_targets.get('email', []) - raw_targets['email'].append(keyword_auto) - elif opt == '-c': + for opt, val in opts: + if opt == "-a": + raw_targets["email"] = raw_targets.get("email", []) + raw_targets["email"].append(keyword_auto) + elif opt == "-c": conf_path = val - elif opt == '-C': + elif opt == "-C": admin_copy = True - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-e': - raw_targets['email'] = raw_targets.get('email', []) - raw_targets['email'].append(val) - elif opt == '-h': + elif opt == "-e": + raw_targets["email"] = raw_targets.get("email", []) + raw_targets["email"].append(val) + elif opt == "-h": usage() sys.exit(0) - elif opt == '-I': + elif opt == "-I": user_id = val - elif opt == '-s': + elif opt == "-s": val = val.lower() raw_targets[val] = raw_targets.get(val, []) - raw_targets[val].append('SETTINGS') - elif opt == '-v': + raw_targets[val].append("SETTINGS") + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" % opt) usage() sys.exit(0) if args: - print('Error: Non-option arguments are not supported - missing quotes?') + print("Error: Non-option arguments are not supported - missing quotes?") usage() sys.exit(1) @@ -121,13 +120,13 @@ def usage(name='notifymigoid.py'): print("No user_id provided!") sys.exit(1) - (configuration, username, full_name, addresses, errors) = \ - user_account_notify(user_id, raw_targets, conf_path, db_path, verbose, - admin_copy) + configuration, username, full_name, addresses, errors = user_account_notify( + user_id, raw_targets, conf_path, db_path, verbose, admin_copy + ) if errors: print("Address lookup errors:") - print('\n'.join(errors)) + print("\n".join(errors)) if not addresses: print("Error: found no suitable addresses") @@ -136,11 +135,19 @@ def usage(name='notifymigoid.py'): print("Error: found no username") sys.exit(1) logger = configuration.logger - notify_dict = {'JOB_ID': 'NOJOBID', 'USER_CERT': user_id, 'NOTIFY': []} - for (proto, address_list) in addresses.items(): + notify_dict = {"JOB_ID": "NOJOBID", "USER_CERT": user_id, "NOTIFY": []} + for proto, address_list in addresses.items(): for address in address_list: - notify_dict['NOTIFY'].append('%s: %s' % (proto, address)) - print("Sending internal OpenID account intro for '%s' to:\n%s" % - (user_id, '\n'.join(notify_dict['NOTIFY']))) - notify_user(notify_dict, [user_id, username, full_name], 'ACCOUNTINTRO', - logger, '', configuration) + notify_dict["NOTIFY"].append("%s: %s" % (proto, address)) + print( + "Sending internal OpenID account intro for '%s' to:\n%s" + % (user_id, "\n".join(notify_dict["NOTIFY"])) + ) + notify_user( + notify_dict, + [user_id, username, full_name], + "ACCOUNTINTRO", + logger, + "", + configuration, + ) diff --git a/mig/server/readconfval.py b/mig/server/readconfval.py index 6b4064098..1214873a2 100755 --- a/mig/server/readconfval.py +++ b/mig/server/readconfval.py @@ 
@@ -30,8 +30,7 @@
 other components outside the actual python code.
 """

-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function

 import getopt
 import os
@@ -40,7 +39,7 @@
 from mig.shared.conf import get_configuration_object

-def usage(name='readconfval.py'):
+def usage(name="readconfval.py"):
     """Usage help"""
     print("""Lookup a evaluated configuration value using MiGserver.conf.
@@ -51,44 +50,44 @@ def usage(name='readconfval.py'):
 -f Force operations to continue past errors
 -h Show this help
 -v Verbose output
-""" % {'name': name})
+""" % {"name": name})

-if '__main__' == __name__:
+if "__main__" == __name__:
     args = sys.argv[1:]
     conf_path = None
     force = False
     verbose = False
-    opt_args = 'c:fhv'
+    opt_args = "c:fhv"
     try:
-        (opts, args) = getopt.getopt(args, opt_args)
+        opts, args = getopt.getopt(args, opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)

-    for (opt, val) in opts:
-        if opt == '-c':
+    for opt, val in opts:
+        if opt == "-c":
             conf_path = val
-        elif opt == '-f':
+        elif opt == "-f":
             force = True
-        elif opt == '-h':
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-v':
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)

     if conf_path and not os.path.isfile(conf_path):
-        print('Failed to read configuration file: %s' % conf_path)
+        print("Failed to read configuration file: %s" % conf_path)
         sys.exit(1)

     if verbose:
         if conf_path:
-            print('using configuration in %s' % conf_path)
+            print("using configuration in %s" % conf_path)
         else:
-            print('using configuration from MIG_CONF (or default)')
+            print("using configuration from MIG_CONF (or default)")

     if len(args) == 1:
         name = args[0]
@@ -97,12 +96,12 @@ def usage(name='readconfval.py'):
         sys.exit(1)

     if verbose:
-        print('Lookup configuration value for %s' % name)
+        print("Lookup configuration value for %s" % name)

     retval = 42
     try:
         configuration = get_configuration_object(skip_log=True)
-        val = getattr(configuration, name, 'UNKNOWN')
-        if val != 'UNKNOWN':
+        val = getattr(configuration, name, "UNKNOWN")
+        if val != "UNKNOWN":
             retval = 0
             print("%s" % val)
     except Exception as err:
diff --git a/mig/server/refreshusers.py b/mig/server/refreshusers.py
index 11f1e14f9..d9e7c6de1 100755
--- a/mig/server/refreshusers.py
+++ b/mig/server/refreshusers.py
@@ -39,13 +39,22 @@
 import sys
 import time

-from mig.shared.defaults import AUTH_EXT_OID, AUTH_EXT_OIDC, AUTH_MIG_CERT, \
-    AUTH_MIG_OID, gdp_distinguished_field
-from mig.shared.useradm import assure_current_htaccess, default_search, \
-    init_user_adm, search_users
-
-
-def usage(name='refreshusers.py'):
+from mig.shared.defaults import (
+    AUTH_EXT_OID,
+    AUTH_EXT_OIDC,
+    AUTH_MIG_CERT,
+    AUTH_MIG_OID,
+    gdp_distinguished_field,
+)
+from mig.shared.useradm import (
+    assure_current_htaccess,
+    default_search,
+    init_user_adm,
+    search_users,
+)
+
+
+def usage(name="refreshusers.py"):
     """Usage help."""
     print("""Refresh MiG user user files and dirs based on user ID in MiG
 user database.
@@ -62,12 +71,11 @@ def usage(name='refreshusers.py'):
 -I CERT_DN Filter to user(s) with ID (distinguished name)
 -s SHORT_ID Filter to user(s) with given short ID field
 -v Verbose output
-"""
-          % {'name': name})
+""" % {"name": name})

-if '__main__' == __name__:
-    (args, _, db_path) = init_user_adm()
+if "__main__" == __name__:
+    args, _, db_path = init_user_adm()
     conf_path = None
     force = False
     verbose = False
@@ -75,127 +83,144 @@ def usage(name='refreshusers.py'):
     now = int(time.time())
     search_filter = default_search()
     # Default to all users with expire range between now and in 30 days
-    search_filter['distinguished_name'] = '*'
-    search_filter['short_id'] = '*'
-    search_filter['expire_after'] = now
-    search_filter['expire_before'] = int(time.time() + 365 * 24 * 3600)
-    opt_args = 'A:B:c:d:fhI:s:v'
+    search_filter["distinguished_name"] = "*"
+    search_filter["short_id"] = "*"
+    search_filter["expire_after"] = now
+    search_filter["expire_before"] = int(time.time() + 365 * 24 * 3600)
+    opt_args = "A:B:c:d:fhI:s:v"
     try:
-        (opts, args) = getopt.getopt(args, opt_args)
+        opts, args = getopt.getopt(args, opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)

-    for (opt, val) in opts:
-        if opt == '-A':
+    for opt, val in opts:
+        if opt == "-A":
             after = now
-            if val.startswith('+'):
+            if val.startswith("+"):
                 after += int(val[1:])
-            elif val.startswith('-'):
+            elif val.startswith("-"):
                 after -= int(val[1:])
             else:
                 after = int(val)
-            search_filter['expire_after'] = after
-        elif opt == '-B':
+            search_filter["expire_after"] = after
+        elif opt == "-B":
             before = now
-            if val.startswith('+'):
+            if val.startswith("+"):
                 before += int(val[1:])
-            elif val.startswith('-'):
+            elif val.startswith("-"):
                 before -= int(val[1:])
             else:
                 before = int(val)
-            search_filter['expire_before'] = before
-        elif opt == '-c':
+            search_filter["expire_before"] = before
+        elif opt == "-c":
             conf_path = val
-        elif opt == '-d':
+        elif opt == "-d":
             db_path = val
-        elif opt == '-f':
+        elif opt == "-f":
             force = True
-        elif opt == '-h':
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-I':
-            search_filter['distinguished_name'] = val
-        elif opt == '-s':
-            search_filter['short_id'] = val
-        elif opt == '-v':
+        elif opt == "-I":
+            search_filter["distinguished_name"] = val
+        elif opt == "-s":
+            search_filter["short_id"] = val
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)
             sys.exit(1)

     if args:
-        print('Error: Non-option arguments are not supported - missing quotes?')
+        print("Error: Non-option arguments are not supported - missing quotes?")
         usage()
         sys.exit(1)

-    (configuration, hits) = search_users(search_filter, conf_path, db_path,
-                                         verbose)
+    configuration, hits = search_users(
+        search_filter, conf_path, db_path, verbose
+    )
     logger = configuration.logger
     gdp_prefix = "%s=" % gdp_distinguished_field
     # NOTE: we already filtered expired accounts here
-    search_dn = search_filter['distinguished_name']
+    search_dn = search_filter["distinguished_name"]
     before_date = datetime.datetime.fromtimestamp(
-        search_filter['expire_before'])
-    after_date = datetime.datetime.fromtimestamp(search_filter['expire_after'])
+        search_filter["expire_before"]
+    )
+    after_date = datetime.datetime.fromtimestamp(search_filter["expire_after"])
     if verbose:
         if hits:
-            print("Check %d account(s) expiring between %s and %s for ID %r" %
-                  (len(hits), after_date, before_date, search_dn))
+            print(
+                "Check %d account(s) expiring between %s and %s for ID %r"
+                % (len(hits), after_date, before_date, search_dn)
+            )
         else:
-            print("No accounts expire between %s and %s for ID %r" %
-                  (after_date, before_date, search_dn))
+            print(
+                "No accounts expire between %s and %s for ID %r"
+                % (after_date, before_date, search_dn)
+            )

-    for (user_id, user_dict) in hits:
+    for user_id, user_dict in hits:
         if verbose:
-            print('Check refresh needed for %r' % user_id)
+            print("Check refresh needed for %r" % user_id)
         # NOTE: gdp accounts don't actually use .htaccess but cat.py serving
-        if configuration.site_enable_gdp and \
-                user_id.split('/')[-1].startswith(gdp_prefix):
+        if configuration.site_enable_gdp and user_id.split("/")[-1].startswith(
+            gdp_prefix
+        ):
             if verbose:
-                print("Handling GDP project account %r despite no effect" %
-                      user_id)
+                print(
+                    "Handling GDP project account %r despite no effect"
+                    % user_id
+                )

         # Don't warn about already disabled or suspended accounts
-        account_state = user_dict.get('status', 'active')
-        if account_state not in ('active', 'temporal'):
+        account_state = user_dict.get("status", "active")
+        if account_state not in ("active", "temporal"):
             if verbose:
-                print('Skip handling of already %s user %r' % (account_state,
-                                                               user_id))
+                print(
+                    "Skip handling of already %s user %r"
+                    % (account_state, user_id)
+                )
             continue

-        known_auth = user_dict.get('auth', [])
+        known_auth = user_dict.get("auth", [])
         if not known_auth:
-            if user_dict.get('main_id', ''):
+            if user_dict.get("main_id", ""):
                 known_auth.append(AUTH_EXT_OIDC)
-            elif user_dict.get('openid_names', []):
-                if user_dict.get('password_hash', ''):
+            elif user_dict.get("openid_names", []):
+                if user_dict.get("password_hash", ""):
                     known_auth.append(AUTH_MIG_OID)
                 else:
                     known_auth.append(AUTH_EXT_OID)
-            elif user_dict.get('password', ''):
+            elif user_dict.get("password", ""):
                 known_auth.append(AUTH_MIG_CERT)
             else:
                 if verbose:
-                    print('Skip handling of user %r without auth info' %
-                          user_id)
+                    print(
+                        "Skip handling of user %r without auth info" % user_id
+                    )
                 continue

         # The auth list changed at one point so it may contain alias or name
-        oid_auth = [i for i in known_auth if i in [AUTH_EXT_OID, AUTH_EXT_OIDC,
-                                                   'extoid', 'extoidc']]
+        oid_auth = [
+            i
+            for i in known_auth
+            if i in [AUTH_EXT_OID, AUTH_EXT_OIDC, "extoid", "extoidc"]
+        ]
         if not oid_auth:
             if verbose:
-                print('Skip handling of user %r without extoid(c) auth' %
-                      user_id)
+                print(
+                    "Skip handling of user %r without extoid(c) auth" % user_id
+                )
             continue
         if verbose:
-            print('Assure current htaccess for %r account' % user_id)
-        if not assure_current_htaccess(configuration, user_id, user_dict,
-                                       force, verbose):
+            print("Assure current htaccess for %r account" % user_id)
+        if not assure_current_htaccess(
+            configuration, user_id, user_dict, force, verbose
+        ):
             exit_code += 1

     sys.exit(exit_code)
diff --git a/mig/server/rejectuser.py b/mig/server/rejectuser.py
index e4eeb785a..cc1d82558 100755
--- a/mig/server/rejectuser.py
+++ b/mig/server/rejectuser.py
@@ -27,8 +27,7 @@
 """Reject a MiG user request and send email with reason"""

-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function

 import getopt
 import os
@@ -41,7 +40,7 @@
 from mig.shared.useradm import init_user_adm, user_request_reject

-def usage(name='rejectuser.py'):
+def usage(name="rejectuser.py"):
     """Usage help"""
     print("""Reject MiG user account request and inform by email.
@@ -60,64 +59,64 @@ def usage(name='rejectuser.py'):
 -s Skip automatic notification to email in user request
 -u USER_FILE Read user information from pickle file
 -v Verbose output
-""" % {'name': name})
+""" % {"name": name})

-if '__main__' == __name__:
-    (args, app_dir, db_path) = init_user_adm()
-    auth_type = 'oid'
+if "__main__" == __name__:
+    args, app_dir, db_path = init_user_adm()
+    auth_type = "oid"
     conf_path = None
     force = False
     reason = "invalid or missing mandatory info"
     verbose = False
     admin_copy = False
     user_copy = True
-    raw_targets = {'email': []}
+    raw_targets = {"email": []}
     user_file = None
     user_id = None
     user_dict = {}
     override_fields = {}
-    opt_args = 'a:c:Cd:e:fhi:r:su:v'
+    opt_args = "a:c:Cd:e:fhi:r:su:v"
     try:
-        (opts, args) = getopt.getopt(args, opt_args)
+        opts, args = getopt.getopt(args, opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)

-    for (opt, val) in opts:
-        if opt == '-a':
+    for opt, val in opts:
+        if opt == "-a":
             auth_type = val
-        elif opt == '-c':
+        elif opt == "-c":
             conf_path = val
-        elif opt == '-C':
+        elif opt == "-C":
             admin_copy = True
-        elif opt == '-d':
+        elif opt == "-d":
             db_path = val
-        elif opt == '-e':
-            raw_targets['email'].append(val)
-        elif opt == '-f':
+        elif opt == "-e":
+            raw_targets["email"].append(val)
+        elif opt == "-f":
             force = True
-        elif opt == '-h':
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-i':
+        elif opt == "-i":
             user_id = val
-        elif opt == '-r':
+        elif opt == "-r":
             reason = val
-        elif opt == '-s':
+        elif opt == "-s":
             user_copy = False
-        elif opt == '-u':
+        elif opt == "-u":
             user_file = val
-        elif opt == '-v':
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)
             usage()
             sys.exit(1)

     if args:
-        print('Error: Non-option arguments are not supported - missing quotes?')
+        print("Error: Non-option arguments are not supported - missing quotes?")
         usage()
         sys.exit(1)
@@ -126,53 +125,63 @@ def usage(name='rejectuser.py'):
         sys.exit(1)

     if auth_type not in valid_auth_types:
-        print('Error: invalid account auth type %r requested (allowed: %s)' %
-              (auth_type, ', '.join(valid_auth_types)))
+        print(
+            "Error: invalid account auth type %r requested (allowed: %s)"
+            % (auth_type, ", ".join(valid_auth_types))
+        )
         usage()
         sys.exit(1)

     try:
         user_dict = load(user_file)
     except Exception as err:
-        print('Error in user name extraction: %s' % err)
+        print("Error in user name extraction: %s" % err)
         usage()
         sys.exit(1)

     if user_id:
-        user_dict['distinguished_name'] = user_id
-    elif 'distinguished_name' not in user_dict:
+        user_dict["distinguished_name"] = user_id
+    elif "distinguished_name" not in user_dict:
         fill_distinguished_name(user_dict)
     fill_user(user_dict)
-    user_id = user_dict['distinguished_name']
+    user_id = user_dict["distinguished_name"]

     # Optionally inform mail used in request
     if user_copy:
-        raw_targets['email'].append(user_dict.get('email', keyword_auto))
+        raw_targets["email"].append(user_dict.get("email", keyword_auto))

     # Now all user fields are set and we can reject and warn the user
-    (configuration, addresses, errors) = \
-        user_request_reject(user_id, raw_targets, conf_path,
-                            db_path, verbose, admin_copy)
+    configuration, addresses, errors = user_request_reject(
+        user_id, raw_targets, conf_path, db_path, verbose, admin_copy
+    )
     if errors:
         print("Address lookup errors:")
-        print('\n'.join(errors))
+        print("\n".join(errors))

     if not addresses:
         print("Error: found no suitable addresses")
         sys.exit(1)

     logger = configuration.logger
-    notify_dict = {'JOB_ID': 'NOJOBID', 'USER_CERT': user_id, 'NOTIFY': []}
-    for (proto, address_list) in addresses.items():
+    notify_dict = {"JOB_ID": "NOJOBID", "USER_CERT": user_id, "NOTIFY": []}
+    for proto, address_list in addresses.items():
         for address in address_list:
-            notify_dict['NOTIFY'].append('%s: %s' % (proto, address))
-    print("Sending reject account request for '%s' to:\n%s" %
-          (user_id, '\n'.join(notify_dict['NOTIFY'])))
-    notify_user(notify_dict, [user_id, user_dict, auth_type, reason],
-                'ACCOUNTREQUESTREJECT', logger, '', configuration)
+            notify_dict["NOTIFY"].append("%s: %s" % (proto, address))
+    print(
+        "Sending reject account request for '%s' to:\n%s"
+        % (user_id, "\n".join(notify_dict["NOTIFY"]))
+    )
+    notify_user(
+        notify_dict,
+        [user_id, user_dict, auth_type, reason],
+        "ACCOUNTREQUESTREJECT",
+        logger,
+        "",
+        configuration,
+    )

     if verbose:
-        print('Cleaning up tmp file: %s' % user_file)
+        print("Cleaning up tmp file: %s" % user_file)
     os.remove(user_file)
diff --git a/mig/server/reqacceptpeer.py b/mig/server/reqacceptpeer.py
index 5d78de1dc..cf27fcacf 100755
--- a/mig/server/reqacceptpeer.py
+++ b/mig/server/reqacceptpeer.py
@@ -33,26 +33,32 @@
 configured additional messaging protocols they can also be used.
""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import os import sys -from mig.shared.accountreq import peers_permit_allowed, manage_pending_peers -from mig.shared.base import fill_distinguished_name, client_id_dir +from mig.shared.accountreq import manage_pending_peers, peers_permit_allowed +from mig.shared.base import client_id_dir, fill_distinguished_name from mig.shared.conf import get_configuration_object -from mig.shared.defaults import keyword_auto, gdp_distinguished_field, \ - pending_peers_filename +from mig.shared.defaults import ( + gdp_distinguished_field, + keyword_auto, + pending_peers_filename, +) from mig.shared.notification import notify_user -from mig.shared.serial import load, dump -from mig.shared.useradm import init_user_adm, search_users, default_search, \ - user_account_notify +from mig.shared.serial import dump, load +from mig.shared.useradm import ( + default_search, + init_user_adm, + search_users, + user_account_notify, +) from mig.shared.validstring import valid_email_addresses -def usage(name='reqacceptpeer.py'): +def usage(name="reqacceptpeer.py"): """Usage help""" print("""Request formal acceptance of external account request to user(s) @@ -80,11 +86,11 @@ def usage(name='reqacceptpeer.py'): One or more destinations may be set by combining multiple -e, -s and -a options. -""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None verbose = False admin_copy = False @@ -93,154 +99,164 @@ def usage(name='reqacceptpeer.py'): user_id = None search_filter = default_search() # IMPORTANT: Default to nobody to avoid spam if called without -I CLIENT_ID - search_filter['distinguished_name'] = '' + search_filter["distinguished_name"] = "" peer_dict = {} regex_keys = [] exit_code = 0 - opt_args = 'ac:Cd:e:E:hI:s:u:v' + opt_args = "ac:Cd:e:E:hI:s:u:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-a': - raw_targets['email'] = raw_targets.get('email', []) - raw_targets['email'].append(keyword_auto) - elif opt == '-c': + for opt, val in opts: + if opt == "-a": + raw_targets["email"] = raw_targets.get("email", []) + raw_targets["email"].append(keyword_auto) + elif opt == "-c": conf_path = val - elif opt == '-C': + elif opt == "-C": admin_copy = True - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-e': - raw_targets['email'] = raw_targets.get('email', []) - raw_targets['email'].append(val) - elif opt == '-E': + elif opt == "-e": + raw_targets["email"] = raw_targets.get("email", []) + raw_targets["email"].append(val) + elif opt == "-E": if val != keyword_auto: - search_filter['email'] = val.lower() + search_filter["email"] = val.lower() else: - search_filter['email'] = keyword_auto - elif opt == '-h': + search_filter["email"] = keyword_auto + elif opt == "-h": usage() sys.exit(0) - elif opt == '-I': - search_filter['distinguished_name'] = val - elif opt == '-s': + elif opt == "-I": + search_filter["distinguished_name"] = val + elif opt == "-s": val = val.lower() raw_targets[val] = raw_targets.get(val, []) - raw_targets[val].append('SETTINGS') - elif opt == '-u': + raw_targets[val].append("SETTINGS") + elif opt == 
"-u": user_file = val - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" % opt) usage() sys.exit(0) if conf_path and not os.path.isfile(conf_path): - print('Failed to read configuration file: %s' % conf_path) + print("Failed to read configuration file: %s" % conf_path) sys.exit(1) if verbose: if conf_path: if verbose: - print('using configuration in %s' % conf_path) + print("using configuration in %s" % conf_path) else: if verbose: - print('using configuration from MIG_CONF (or default)') + print("using configuration from MIG_CONF (or default)") configuration = get_configuration_object(config_file=conf_path) logger = configuration.logger if user_file and args: - print('Error: Only one kind of user specification allowed at a time') + print("Error: Only one kind of user specification allowed at a time") usage() sys.exit(1) if args: try: - peer_dict['full_name'] = args[0] - peer_dict['organization'] = args[1] - peer_dict['state'] = args[2] - peer_dict['country'] = args[3] - peer_dict['email'] = args[4] - peer_dict['comment'] = args[5] + peer_dict["full_name"] = args[0] + peer_dict["organization"] = args[1] + peer_dict["state"] = args[2] + peer_dict["country"] = args[3] + peer_dict["email"] = args[4] + peer_dict["comment"] = args[5] except IndexError: - print('Error: too few arguments given (expected 6 got %d)' - % len(args)) + print( + "Error: too few arguments given (expected 6 got %d)" % len(args) + ) usage() sys.exit(1) elif user_file: try: peer_dict = load(user_file) except Exception as err: - print('Error in user name extraction: %s' % err) + print("Error in user name extraction: %s" % err) usage() sys.exit(1) else: - print('No peer specified: please pass peer as args or with -u PATH') + print("No peer specified: please pass peer as args or with -u PATH") usage() sys.exit(1) fill_distinguished_name(peer_dict) - peer_id = peer_dict['distinguished_name'] + peer_id = peer_dict["distinguished_name"] - if search_filter['email'] == keyword_auto: + if search_filter["email"] == keyword_auto: peer_email_list = [] # Extract email of peers contact from explicit peers field or comment # We don't try peers full name here as it is far too tricky to match - peers_email = peer_dict.get('peers_email', '') - comment = peer_dict.get('comment', '') + peers_email = peer_dict.get("peers_email", "") + comment = peer_dict.get("comment", "") peers_source = "%s\n%s" % (peers_email, comment) peer_emails = valid_email_addresses(configuration, peers_source) if peer_emails[1:]: - regex_keys.append('email') - search_filter['email'] = '(' + '|'.join(peer_emails) + ')' + regex_keys.append("email") + search_filter["email"] = "(" + "|".join(peer_emails) + ")" elif peer_emails: - search_filter['email'] = peer_emails[0] - elif search_filter['distinguished_name']: - search_filter['email'] = '*' + search_filter["email"] = peer_emails[0] + elif search_filter["distinguished_name"]: + search_filter["email"] = "*" else: - search_filter['email'] = '' + search_filter["email"] = "" # If email is provided or detected DN may be almost anything - if search_filter['email'] and not search_filter['distinguished_name']: - search_filter['distinguished_name'] = '*emailAddress=*' + if search_filter["email"] and not search_filter["distinguished_name"]: + search_filter["distinguished_name"] = "*emailAddress=*" if verbose: - print('Handling peer %s request to users matching %s' % - (peer_id, search_filter)) + print( + "Handling peer %s request to users 
matching %s" + % (peer_id, search_filter) + ) # Lookup users to request formal acceptance from - (_, hits) = search_users(search_filter, conf_path, - db_path, verbose, regex_match=regex_keys) + _, hits = search_users( + search_filter, conf_path, db_path, verbose, regex_match=regex_keys + ) logger = configuration.logger gdp_prefix = "%s=" % gdp_distinguished_field if len(hits) < 1: print( - "Aborting attempt to request peer acceptance without target users") + "Aborting attempt to request peer acceptance without target users" + ) print(" ... did you forget or supply too rigid -E EMAIL or -I DN arg?") sys.exit(1) elif len(hits) > 3: - print("Aborting attempt to request peer acceptance from %d users!" % - len(hits)) + print( + "Aborting attempt to request peer acceptance from %d users!" + % len(hits) + ) print(" ... did you supply too lax -E EMAIL or -I DN argument?") sys.exit(1) else: if verbose: - print("Attempt to request peer acceptance from users: %s" % - '\n'.join([i[0] for i in hits])) + print( + "Attempt to request peer acceptance from users: %s" + % "\n".join([i[0] for i in hits]) + ) - for (user_id, user_dict) in hits: + for user_id, user_dict in hits: if verbose: - print('Check for %s' % user_id) + print("Check for %s" % user_id) - if configuration.site_enable_gdp and \ - user_id.split('/')[-1].startswith(gdp_prefix): + if configuration.site_enable_gdp and user_id.split("/")[-1].startswith( + gdp_prefix + ): if verbose: print("Skip GDP project account: %s" % user_id) continue @@ -253,36 +269,52 @@ def usage(name='reqacceptpeer.py'): print("Skip account %s without vouching permission" % user_id) continue - if not manage_pending_peers(configuration, user_id, "add", - [(peer_id, peer_dict)]): - print("Failed to forward accept peer %s to %s" % - (peer_id, user_id)) + if not manage_pending_peers( + configuration, user_id, "add", [(peer_id, peer_dict)] + ): + print("Failed to forward accept peer %s to %s" % (peer_id, user_id)) continue print("Added peer request from %s to %s" % (peer_id, user_id)) - (_, _, full_name, addresses, errors) = user_account_notify( - user_id, raw_targets, conf_path, db_path, verbose, admin_copy) + _, _, full_name, addresses, errors = user_account_notify( + user_id, raw_targets, conf_path, db_path, verbose, admin_copy + ) if errors: print("Address lookup errors for %s :" % user_id) - print('\n'.join(errors)) + print("\n".join(errors)) exit_code += 1 continue - notify_dict = {'JOB_ID': 'NOJOBID', 'USER_CERT': user_id, 'NOTIFY': []} - for (proto, address_list) in addresses.items(): + notify_dict = {"JOB_ID": "NOJOBID", "USER_CERT": user_id, "NOTIFY": []} + for proto, address_list in addresses.items(): for address in address_list: - notify_dict['NOTIFY'].append('%s: %s' % (proto, address)) + notify_dict["NOTIFY"].append("%s: %s" % (proto, address)) # Don't actually send unless requested if not raw_targets and not admin_copy: - print("No email targets for request accept peer %s from %s" % - (peer_id, user_id)) + print( + "No email targets for request accept peer %s from %s" + % (peer_id, user_id) + ) continue - print("Send request accept peer message for '%s' to:\n%s" - % (peer_id, '\n'.join(notify_dict['NOTIFY']))) - notify_user(notify_dict, [peer_id, configuration.short_title, - 'peeraccount', peer_dict['comment'], - peer_dict['email'], user_id], - 'SENDREQUEST', logger, '', configuration) + print( + "Send request accept peer message for '%s' to:\n%s" + % (peer_id, "\n".join(notify_dict["NOTIFY"])) + ) + notify_user( + notify_dict, + [ + peer_id, + 
+                configuration.short_title,
+                "peeraccount",
+                peer_dict["comment"],
+                peer_dict["email"],
+                user_id,
+            ],
+            "SENDREQUEST",
+            logger,
+            "",
+            configuration,
+        )

     sys.exit(exit_code)
diff --git a/mig/server/reset2fakey.py b/mig/server/reset2fakey.py
index 12ab1b820..db33542e7 100755
--- a/mig/server/reset2fakey.py
+++ b/mig/server/reset2fakey.py
@@ -27,24 +27,26 @@
 """(Re)set user 2FA key"""

-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function

-from builtins import range
 import base64
 import datetime
 import getopt
 import os
 import sys
 import tempfile
+from builtins import range

 import pyotp

 from mig.shared.auth import reset_twofactor_key, valid_otp_window
 from mig.shared.base import client_id_dir
-from mig.shared.defaults import twofactor_filename, twofactor_key_name, \
-    twofactor_interval_name
 from mig.shared.conf import get_configuration_object
+from mig.shared.defaults import (
+    twofactor_filename,
+    twofactor_interval_name,
+    twofactor_key_name,
+)
 from mig.shared.fileio import delete_file
 from mig.shared.gdp.all import project_close
 from mig.shared.settings import load_twofactor, parse_and_save_twofactor
@@ -61,23 +63,24 @@ def enable2fa(configuration, user_id, verbose, force=False):
     if current_twofactor_dict:
         return True
     keywords_dict = twofactor_keywords(configuration)
-    topic_mrsl = ''
+    topic_mrsl = ""
     for keyword in keywords_dict:
-        topic_mrsl += '''::%s::
+        topic_mrsl += """::%s::
 %s

-''' % (keyword.upper(), 'True')
+""" % (keyword.upper(), "True")
     try:
-        (filehandle, tmptopicfile) = tempfile.mkstemp(text=True)
+        filehandle, tmptopicfile = tempfile.mkstemp(text=True)
         os.write(filehandle, topic_mrsl)
         os.close(filehandle)
     except Exception as exc:
         msg = "Error: Problem writing temporary topic file on server."
         print("%s : %s" % (msg, exc))
         return False
-    (parse_status, _) = parse_and_save_twofactor(tmptopicfile, user_id,
-                                                 configuration)
+    parse_status, _ = parse_and_save_twofactor(
+        tmptopicfile, user_id, configuration
+    )
     if parse_status:
         print("Enabled all two-factor services for user: %r" % user_id)
     else:
@@ -120,13 +123,14 @@ def remove2fa(configuration, user_id, verbose, force=False):
     twofactor_settings_path = os.path.join(settings_dir, twofactor_filename)
     if verbose:
         print("Removing twofactor file: %s" % twofactor_settings_path)
-    status = delete_file(twofactor_settings_path, _logger,
-                         allow_missing=allow_missing)
+    status = delete_file(
+        twofactor_settings_path, _logger, allow_missing=allow_missing
+    )
     return status

-def usage(name='reset2fakey.py'):
+def usage(name="reset2fakey.py"):
     """Usage help"""
     print("""(Re)set user 2FA key.
@@ -140,13 +144,12 @@ def usage(name='reset2fakey.py'):
 -a Enable 2fa for all services
 -r Remove 2fa for all services
 -v Verbose output
-"""
-          % {'name': name})
+""" % {"name": name})

 # ## Main ###

-if '__main__' == __name__:
+if "__main__" == __name__:
     conf_path = None
     force = False
     verbose = False
@@ -156,62 +159,62 @@ def usage(name='reset2fakey.py'):
     seed_file = None
     interval = None
     remove = False
-    opt_args = 'c:fhari:v'
+    opt_args = "c:fhari:v"
     try:
-        (opts, args) = getopt.getopt(sys.argv[1:], opt_args)
+        opts, args = getopt.getopt(sys.argv[1:], opt_args)
     except getopt.GetoptError as err:
-        print('Error: ', err.msg)
+        print("Error: ", err.msg)
         usage()
         sys.exit(1)

-    for (opt, val) in opts:
-        if opt == '-c':
+    for opt, val in opts:
+        if opt == "-c":
             conf_path = val
-        elif opt == '-f':
+        elif opt == "-f":
             force = True
-        elif opt == '-h':
+        elif opt == "-h":
             usage()
             sys.exit(0)
-        elif opt == '-i':
+        elif opt == "-i":
             user_id = val
-        elif opt == '-a':
+        elif opt == "-a":
             enable_all = True
-        elif opt == '-r':
+        elif opt == "-r":
             remove = True
-        elif opt == '-v':
+        elif opt == "-v":
             verbose = True
         else:
-            print('Error: %s not supported!' % opt)
+            print("Error: %s not supported!" % opt)

     if conf_path and not os.path.isfile(conf_path):
-        print('Failed to read configuration file: %s' % conf_path)
+        print("Failed to read configuration file: %s" % conf_path)
         sys.exit(1)

     if verbose:
         if conf_path:
-            os.environ['MIG_CONF'] = conf_path
-            print('using configuration in %s' % conf_path)
+            os.environ["MIG_CONF"] = conf_path
+            print("using configuration in %s" % conf_path)
         else:
-            print('using configuration from MIG_CONF (or default)')
+            print("using configuration from MIG_CONF (or default)")

     configuration = get_configuration_object(skip_log=True)
     if not configuration.site_enable_twofactor:
-        print('Error: Two-factor authentication disabled for site')
+        print("Error: Two-factor authentication disabled for site")
         sys.exit(1)

     if not user_id:
-        print('Error: Existing user ID is required')
+        print("Error: Existing user ID is required")
         usage()
         sys.exit(1)

     if configuration.site_enable_gdp:
-        status = project_close(configuration,
-                               'https',
-                               '127.0.0.1',
-                               user_id=user_id)
+        status = project_close(
+            configuration, "https", "127.0.0.1", user_id=user_id
+        )
         if not status:
             print(
-                'Warning: Project close failed, user probably not logged in to any projects')
+                "Warning: Project close failed, user probably not logged in to any projects"
+            )

     if remove:
         status = remove2fa(configuration, user_id, verbose, force)
@@ -223,7 +226,7 @@ def usage(name='reset2fakey.py'):
         sys.exit(1)

     if not enable2fa(configuration, user_id, verbose, enable_all):
-        print('Error: Failed to enable two-factor authentication')
+        print("Error: Failed to enable two-factor authentication")
         sys.exit(1)

     if args:
@@ -238,7 +241,7 @@ def usage(name='reset2fakey.py'):
     if seed_file:
         # TODO: port to read_file helper
         try:
-            s_fd = open(seed_file, 'r')
+            s_fd = open(seed_file, "r")
             seed = s_fd.read().strip()
             s_fd.close()
         except Exception as exc:
@@ -275,8 +278,9 @@ def usage(name='reset2fakey.py'):
     if interval:
         print("using interval: %s" % interval)

-    twofa_key = reset_twofactor_key(user_id, configuration,
-                                    seed=seed, interval=interval)
+    twofa_key = reset_twofactor_key(
+        user_id, configuration, seed=seed, interval=interval
+    )
     if verbose:
         print("New two factor key: %s" % twofa_key)
@@ -290,19 +294,30 @@ def usage(name='reset2fakey.py'):
             totp_custom_totp = pyotp.TOTP(twofa_key, interval=interval)

         if valid_otp_window == 0:
%s" - % totp_default.at(current_time, 0)) + print( + "default interval, code: %s" + % totp_default.at(current_time, 0) + ) if totp_custom_totp: - print("interval: %d, code: %s" - % (interval, totp_custom_totp.at(current_time, 0))) + print( + "interval: %d, code: %s" + % (interval, totp_custom_totp.at(current_time, 0)) + ) else: for i in range(-valid_otp_window, valid_otp_window + 1): - print("default interval, window: %d, code: %s" - % (i, totp_default.at(current_time, i))) + print( + "default interval, window: %d, code: %s" + % (i, totp_default.at(current_time, i)) + ) if totp_custom_totp: - print("interval: %d, window: %d, code: %s" - % (interval, i, - totp_custom_totp.at(current_time, i))) + print( + "interval: %d, window: %d, code: %s" + % ( + interval, + i, + totp_custom_totp.at(current_time, i), + ) + ) else: print("Failed to reset two factor key") sys.exit(1) diff --git a/mig/server/resetcaches.py b/mig/server/resetcaches.py index c25230d45..97c081641 100644 --- a/mig/server/resetcaches.py +++ b/mig/server/resetcaches.py @@ -32,36 +32,44 @@ import os import sys - from mig.shared.base import client_id_dir from mig.shared.conf import get_configuration_object from mig.shared.fileio import delete_file -from mig.shared.vgridaccess import refresh_vgrid_map, refresh_user_map, \ - refresh_resource_map +from mig.shared.vgridaccess import ( + refresh_resource_map, + refresh_user_map, + refresh_vgrid_map, +) -def refresh_maps(configuration, map_list, verbose, force=False, allow_missing=True): +def refresh_maps( + configuration, map_list, verbose, force=False, allow_missing=True +): """Make sure one or more of vgrid, user and resource maps are refreshed""" _logger = configuration.logger status = True for map_name in map_list: - for root in (configuration.mig_system_files, configuration.mig_system_run): - for ext in ('map', 'lock', 'modified'): + for root in ( + configuration.mig_system_files, + configuration.mig_system_run, + ): + for ext in ("map", "lock", "modified"): sub_path = os.path.join(root, "%s.%s" % (map_name, ext)) if verbose: print("Removing %s" % sub_path) status = delete_file( - sub_path, _logger, allow_missing=allow_missing) + sub_path, _logger, allow_missing=allow_missing + ) if not status and not force: return status - if map_name == 'vgrid': + if map_name == "vgrid": if not refresh_vgrid_map(configuration) and not force: return status - elif map_name == 'user': + elif map_name == "user": if not refresh_user_map(configuration) and not force: return status - elif map_name == 'resource': + elif map_name == "resource": if not refresh_resource_map(configuration) and not force: return status else: @@ -69,7 +77,7 @@ def refresh_maps(configuration, map_list, verbose, force=False, allow_missing=Tr return status -def usage(name='resetcaches.py'): +def usage(name="resetcaches.py"): """Usage help""" print("""(Re)set vgrid, user and resource map caches. @@ -81,57 +89,59 @@ def usage(name='resetcaches.py'): -h Show this help -v Verbose output and MAP_NAME one or more of vgrid, user and resource. 
-""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: +if "__main__" == __name__: conf_path = None force = False verbose = False - opt_args = 'c:fhv' + opt_args = "c:fhv" try: - (opts, args) = getopt.getopt(sys.argv[1:], opt_args) + opts, args = getopt.getopt(sys.argv[1:], opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" % opt) if conf_path and not os.path.isfile(conf_path): - print('Failed to read configuration file: %s' % conf_path) + print("Failed to read configuration file: %s" % conf_path) sys.exit(1) if verbose: if conf_path: - os.environ['MIG_CONF'] = conf_path - print('using configuration in %s' % conf_path) + os.environ["MIG_CONF"] = conf_path + print("using configuration in %s" % conf_path) else: - print('using configuration from MIG_CONF (or default)') + print("using configuration from MIG_CONF (or default)") configuration = get_configuration_object(skip_log=True) if args: map_list = args else: - map_list = ['vgrid', 'user', 'resource'] + map_list = ["vgrid", "user", "resource"] if not refresh_maps(configuration, map_list, verbose, force): - print("Failed to refresh %s map(s) - force may be needed?" % - ', '.join(map_list)) + print( + "Failed to refresh %s map(s) - force may be needed?" + % ", ".join(map_list) + ) sys.exit(1) if verbose: - print("Refreshed %s maps" % ', '.join(map_list)) + print("Refreshed %s maps" % ", ".join(map_list)) sys.exit(0) diff --git a/mig/server/searchusers.py b/mig/server/searchusers.py index 619c725bc..eb9926fcf 100755 --- a/mig/server/searchusers.py +++ b/mig/server/searchusers.py @@ -27,18 +27,18 @@ """Find all users with given data base field(s)""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function -from past.builtins import basestring import getopt import sys import time -from mig.shared.useradm import init_user_adm, search_users, default_search +from past.builtins import basestring + +from mig.shared.useradm import default_search, init_user_adm, search_users -def usage(name='searchusers.py'): +def usage(name="searchusers.py"): """Usage help""" print("""Search in MiG user database. @@ -62,73 +62,74 @@ def usage(name='searchusers.py'): -v Verbose output Each search value can be a string or a pattern with * and ? as wildcards. 
-""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None verbose = False user_dict = {} - opt_args = 'a:b:c:C:d:E:f:F:hI:nO:r:S:v' + opt_args = "a:b:c:C:d:E:f:F:hI:nO:r:S:v" search_filter = default_search() expire_before, expire_after = None, None only_fields = [] try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-a': - search_filter['expire_after'] = int(val) - elif opt == '-b': - search_filter['expire_before'] = int(val) - elif opt == '-c': + for opt, val in opts: + if opt == "-a": + search_filter["expire_after"] = int(val) + elif opt == "-b": + search_filter["expire_before"] = int(val) + elif opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-f': + elif opt == "-f": only_fields.append(val) - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-I': - search_filter['distinguished_name'] = val - elif opt == '-n': - only_fields.append('full_name') - elif opt == '-C': - search_filter['country'] = val - elif opt == '-E': - search_filter['email'] = val - elif opt == '-F': - search_filter['full_name'] = val - elif opt == '-O': - search_filter['organization'] = val - elif opt == '-r': - search_filter['role'] = val - elif opt == '-S': - search_filter['state'] = val - elif opt == '-v': + elif opt == "-I": + search_filter["distinguished_name"] = val + elif opt == "-n": + only_fields.append("full_name") + elif opt == "-C": + search_filter["country"] = val + elif opt == "-E": + search_filter["email"] = val + elif opt == "-F": + search_filter["full_name"] = val + elif opt == "-O": + search_filter["organization"] = val + elif opt == "-r": + search_filter["role"] = val + elif opt == "-S": + search_filter["state"] = val + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) usage() sys.exit(0) regex_patterns = [] - for (key, val) in search_filter.items(): - if isinstance(val, basestring) and val.find('|') != -1: + for key, val in search_filter.items(): + if isinstance(val, basestring) and val.find("|") != -1: regex_patterns.append(key) - (configuration, hits) = search_users(search_filter, conf_path, db_path, - verbose, regex_match=regex_patterns) + configuration, hits = search_users( + search_filter, conf_path, db_path, verbose, regex_match=regex_patterns + ) print("Matching users:") - for (uid, user_dict) in hits: + for uid, user_dict in hits: if only_fields: - field_list = ["%s" % user_dict.get(i, '') for i in only_fields] - print('%s' % ' : '.join(field_list)) + field_list = ["%s" % user_dict.get(i, "") for i in only_fields] + print("%s" % " : ".join(field_list)) else: - print('%s : %s' % (uid, user_dict)) + print("%s : %s" % (uid, user_dict)) diff --git a/mig/server/sftp_subsys.py b/mig/server/sftp_subsys.py index 68d4eeba9..d906cdde3 100755 --- a/mig/server/sftp_subsys.py +++ b/mig/server/sftp_subsys.py @@ -64,16 +64,14 @@ Inspired by https://gist.github.com/lonetwin/3b5982cf88c598c0e169 """ -from __future__ import absolute_import -from __future__ import print_function - -from builtins import object +from __future__ import absolute_import, print_function import io import os import sys import threading import time +from builtins import object try: from paramiko.server import ServerInterface @@ -95,8 +93,11 @@ from mig.server.grid_sftp import SimpleSftpServer as SftpServerImpl from mig.shared.conf import get_configuration_object from mig.shared.fileio import user_chroot_exceptions - from mig.shared.logger import daemon_logger, daemon_gdp_logger, \ - register_hangup_handler + from mig.shared.logger import ( + daemon_gdp_logger, + daemon_logger, + register_hangup_handler, + ) except ImportError: print("ERROR: the migrid modules must be in PYTHONPATH") sys.exit(1) @@ -126,7 +127,7 @@ def __init__(self, stdin, stdout): def send(self, data, flags=0): """Fake send""" - #logger.debug("IOSocketAdapter send: %s" % [data]) + # logger.debug("IOSocketAdapter send: %s" % [data]) self._stdout.flush() self._stdout.write(data) self._stdout.flush() @@ -135,12 +136,12 @@ def send(self, data, flags=0): def recv(self, bufsize, flags=0): """Fake recv""" data = self._stdin.read(bufsize) - #logger.debug("IOSocketAdapter recvd: %s" % [data]) + # logger.debug("IOSocketAdapter recvd: %s" % [data]) return data def close(self): """Fake close""" - #logger.debug("IOSocketAdapter close") + # logger.debug("IOSocketAdapter close") self._stdin.close() self._stdout.close() @@ -151,7 +152,7 @@ def settimeout(self, ignored): def get_name(self): """Used for paramiko logging""" # NOTE: we still need to set transport log explicitly - return 'sftp' + return "sftp" def get_transport(self): """Lazy transport init and getter""" @@ -160,42 +161,46 @@ def get_transport(self): return self._transport -if __name__ == '__main__': +if __name__ == "__main__": # We need to manualy extract MiG conf path since running from openssh - conf_path = os.path.join(os.path.dirname(__file__), 'MiGserver.conf') - os.putenv('MIG_CONF', conf_path) + conf_path = os.path.join(os.path.dirname(__file__), "MiGserver.conf") + os.putenv("MIG_CONF", conf_path) # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) log_level = configuration.loglevel # Use separate logger - logger = daemon_logger('sftp-subsys', configuration.user_sftp_subsys_log, - log_level) + logger = 
daemon_logger( + "sftp-subsys", configuration.user_sftp_subsys_log, log_level + ) configuration.logger = logger auth_logger = daemon_logger( - "sftp-subsys-auth", configuration.user_auth_log, log_level) + "sftp-subsys-auth", configuration.user_auth_log, log_level + ) configuration.auth_logger = auth_logger if configuration.site_enable_gdp: - gdp_logger = daemon_gdp_logger("sftp-subsys-gdp", - level=log_level) + gdp_logger = daemon_gdp_logger("sftp-subsys-gdp", level=log_level) configuration.gdp_logger = gdp_logger # Allow e.g. logrotate to force log re-open after rotates register_hangup_handler(configuration) pid = os.getpid() - logger.info('(%d) Basic sftp subsystem initialized' % pid) + logger.info("(%d) Basic sftp subsystem initialized" % pid) # IMPORTANT: for security reasons we only allow restricted launch # The login shell should NOT evaluate arbitrary user code from # profile or shell rc files and should preferably call this # script directly. More info in the module doc-string above. - fallback_shells = ['/bin/sh'] - login_shell = os.environ.get('SHELL', 'UNKNOWN') - if sys.argv[1:] == ['-c', __file__]: + fallback_shells = ["/bin/sh"] + login_shell = os.environ.get("SHELL", "UNKNOWN") + if sys.argv[1:] == ["-c", __file__]: login_shell = sys.argv[0] if login_shell in fallback_shells: - logger.warning("sftp subsystem not using direct launch but %s" % - login_shell) + logger.warning( + "sftp subsystem not using direct launch but %s" % login_shell + ) elif login_shell != __file__: - logger.error("sftp subsystem launched with illegal/unsafe shell: %s" - % login_shell) + logger.error( + "sftp subsystem launched with illegal/unsafe shell: %s" + % login_shell + ) sys.exit(1) # Lookup chroot exceptions once and for all @@ -204,55 +209,56 @@ def get_transport(self): # in acceptable_chmod helper. 
chmod_exceptions = [] configuration.daemon_conf = { - 'root_dir': os.path.abspath(configuration.user_home), - 'chroot_exceptions': chroot_exceptions, - 'chmod_exceptions': chmod_exceptions, - 'allow_password': 'password' in configuration.user_sftp_auth, - 'allow_digest': False, - 'allow_publickey': 'publickey' in configuration.user_sftp_auth, - 'user_alias': configuration.user_sftp_alias, + "root_dir": os.path.abspath(configuration.user_home), + "chroot_exceptions": chroot_exceptions, + "chmod_exceptions": chmod_exceptions, + "allow_password": "password" in configuration.user_sftp_auth, + "allow_digest": False, + "allow_publickey": "publickey" in configuration.user_sftp_auth, + "user_alias": configuration.user_sftp_alias, # Lock needed here due to threaded creds updates - 'creds_lock': threading.Lock(), - 'users': [], - 'jobs': [], - 'shares': [], - 'jupyter_mounts': [], - 'login_map': {}, - 'hash_cache': {}, - 'time_stamp': 0, - 'logger': logger, - 'stop_running': threading.Event(), + "creds_lock": threading.Lock(), + "users": [], + "jobs": [], + "shares": [], + "jupyter_mounts": [], + "login_map": {}, + "hash_cache": {}, + "time_stamp": 0, + "logger": logger, + "stop_running": threading.Event(), } try: - logger.debug('Create socket adaptor') + logger.debug("Create socket adaptor") socket_adapter = IOSocketAdapter(sys.stdin, sys.stdout) - logger.debug('Create server interface') + logger.debug("Create server interface") server_if = ServerInterface() - logger.debug('Create sftp server') + logger.debug("Create sftp server") # Pass helper vars directly on class to avoid API tampering SftpServerImpl.configuration = configuration SftpServerImpl.conf = configuration.daemon_conf SftpServerImpl.logger = logger - sftp_server = SFTPServer(socket_adapter, 'sftp', server=server_if, - sftp_si=SftpServerImpl) + sftp_server = SFTPServer( + socket_adapter, "sftp", server=server_if, sftp_si=SftpServerImpl + ) # IMPORTANT: make sure spawned client handler thread uses main log - socket_adapter.get_transport().set_log_channel('sftp-subsys') - logger.info('(%s) Start sftp subsys server' % pid) + socket_adapter.get_transport().set_log_channel("sftp-subsys") + logger.info("(%s) Start sftp subsys server" % pid) # NOTE: we explicitly loop and join thread to act on hangup signal try: sftp_server.setDaemon(False) sftp_server.start() except Exception as exc: logger.error("(%d) Crashed with exception: %s" % (pid, exc)) - configuration.daemon_conf['stop_running'].set() + configuration.daemon_conf["stop_running"].set() - logger.info('(%s) Handling client' % pid) + logger.info("(%s) Handling client" % pid) while True: try: - if configuration.daemon_conf['stop_running'].is_set(): + if configuration.daemon_conf["stop_running"].is_set(): # TODO: should we terminate sftp_server here? 
- logger.info('(%d) Join sftp subsys server worker' % pid) + logger.info("(%d) Join sftp subsys server worker" % pid) sftp_server.join() break else: @@ -262,16 +268,17 @@ def get_transport(self): if not sftp_server.is_alive(): # logger.debug( # '(%d) Joined sftp subsys server worker' % pid) - configuration.daemon_conf['stop_running'].set() + configuration.daemon_conf["stop_running"].set() break except KeyboardInterrupt: logger.info("(%d) Received user interrupt" % pid) - configuration.daemon_conf['stop_running'].set() + configuration.daemon_conf["stop_running"].set() except Exception as exc: logger.error("(%d) Crashed with exception: %s" % (pid, exc)) - configuration.daemon_conf['stop_running'].set() - logger.info('(%d) Finished sftp subsys server' % pid) + configuration.daemon_conf["stop_running"].set() + logger.info("(%d) Finished sftp subsys server" % pid) except Exception as exc: - logger.error('(%d) Failed to run sftp subsys server: %s' % (pid, exc)) + logger.error("(%d) Failed to run sftp subsys server: %s" % (pid, exc)) import traceback + logger.error(traceback.format_exc()) diff --git a/mig/server/sftpfailinfo.py b/mig/server/sftpfailinfo.py index ffd829894..e0ef4c99d 100755 --- a/mig/server/sftpfailinfo.py +++ b/mig/server/sftpfailinfo.py @@ -30,10 +30,8 @@ """Grep for sftp negotiation in sftp.log and translate source IP to FQDN""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function -from past.builtins import cmp import getopt import multiprocessing import os @@ -41,11 +39,13 @@ import socket import sys +from past.builtins import cmp + from mig.shared.conf import get_configuration_object from mig.shared.useradm import init_user_adm -def usage(name='sftpfailinfo.py'): +def usage(name="sftpfailinfo.py"): """Usage help""" print("""%(doc)s @@ -58,7 +58,7 @@ def usage(name='sftpfailinfo.py'): -v Verbose output -x TRUSTED_IP Trust IPs starting with this prefix (multiple allowed) -X TRUSTED_DOMAIN Trust FQDNs ending with this suffix (multiple allowed) -""" % {'doc': __doc__, 'name': name}) +""" % {"doc": __doc__, "name": name}) def dns_lookup(ip_addr): @@ -72,40 +72,40 @@ def dns_lookup(ip_addr): return (ip_addr, fqdn) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None # Never blacklist localhost IPs - trust_ip_list = ['127.0.'] + trust_ip_list = ["127.0."] # NOTE: 123.31.32.0/19 in Vietnam maps to 'localhost' - don't trust DNS trust_fqdn_list = [] verbose = False - opt_args = 'c:hvx:X:' + opt_args = "c:hvx:X:" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-c': + for opt, val in opts: + if opt == "-c": conf_path = val - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-v': + elif opt == "-v": verbose = True - elif opt == '-x': + elif opt == "-x": trust_ip_list.append(val.strip()) - elif opt == '-X': + elif opt == "-X": trust_fqdn_list.append(val.strip()) else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) sys.exit(1) if conf_path: - os.environ['MIG_CONF'] = conf_path + os.environ["MIG_CONF"] = conf_path configuration = get_configuration_object() matches = [] extract_pattern = r"(.+) WARNING client negotiation errors for " @@ -117,7 +117,7 @@ def dns_lookup(ip_addr): print("Searching for SFTP negotiation errors in %s" % sftp_log) log_fd = open(sftp_log) for line in log_fd: - if line.find('WARNING client negotiation errors ') != -1: + if line.find("WARNING client negotiation errors ") != -1: matches.append(line) log_fd.close() if verbose: @@ -128,11 +128,11 @@ def dns_lookup(ip_addr): if match: stamp, source_ip, source_port, err_cond = match.group(1, 2, 3, 4) if not source_ip in ip_fail_map: - ip_fail_map[source_ip] = {'source_ip': source_ip} + ip_fail_map[source_ip] = {"source_ip": source_ip} if not err_cond in ip_fail_map[source_ip]: ip_fail_map[source_ip][err_cond] = 0 ip_fail_map[source_ip][err_cond] += 1 - ip_fail_map[source_ip]['last'] = stamp + ip_fail_map[source_ip]["last"] = stamp if not ip_fail_map: if verbose: @@ -145,7 +145,7 @@ def dns_lookup(ip_addr): workers = multiprocessing.Pool(processes=64) rdns_results = workers.map(dns_lookup, list(ip_fail_map)) fqdn_fail_map = {} - for (source_ip, source_fqdn) in rdns_results: + for source_ip, source_fqdn in rdns_results: fqdn_fail_map[source_fqdn] = ip_fail_map[source_ip] print("") @@ -153,24 +153,25 @@ def dns_lookup(ip_addr): print("----------------------") sorted_hosts = list(fqdn_fail_map) # Try to sort in a more intuitive way where TLD is considered first - sorted_hosts.sort(cmp=lambda a, b: cmp(a.split(".")[::-1], - b.split(".")[::-1])) + sorted_hosts.sort( + cmp=lambda a, b: cmp(a.split(".")[::-1], b.split(".")[::-1]) + ) show_limit, offender_limit = 8, 32 for source_fqdn in sorted_hosts: err_map = fqdn_fail_map[source_fqdn] - source_ip = err_map['source_ip'] - last = err_map['last'] + source_ip = err_map["source_ip"] + last = err_map["last"] host_stats = "%s (%s): " % (source_fqdn, source_ip) host_errs = [] total = 0 - for (err_cond, err_count) in err_map.items(): - if err_cond in ['source_ip', 'last']: + for err_cond, err_count in err_map.items(): + if err_cond in ["source_ip", "last"]: continue host_errs.append("%s: %d" % (err_cond, err_count)) total += err_count - host_stats += ' , '.join(host_errs) - host_stats += ' , total: %d' % total - host_stats += ' , last: %s' % last + host_stats += " , ".join(host_errs) + host_stats += " , total: %d" % total + host_stats += " , last: %s" % last # Only display repeated offenders and honor trust trust = False for trust_prefix in trust_ip_list: @@ -189,6 +190,8 @@ def dns_lookup(ip_addr): if total > offender_limit: print(" * You may want to verify origin and block if fishy:") print("\twhois %(source_ip)s|grep 'descr:'" % err_map) - print("\tsudo iptables -I INPUT -s %(source_ip)s/32 -j DROP" - % err_map) + print( + "\tsudo iptables -I INPUT -s %(source_ip)s/32 -j DROP" + % err_map + ) print("") diff --git a/mig/server/usagestats.py b/mig/server/usagestats.py index 04e938df4..87534ec9d 100755 --- a/mig/server/usagestats.py +++ b/mig/server/usagestats.py @@ -27,8 +27,7 @@ """Show basic stats about site users and storage use""" -from __future__ import print_function -from __future__ import absolute_import +from __future__ import absolute_import, print_function import getopt import os @@ -39,14 +38,14 @@ from mig.shared.defaults import freeze_meta_filename, keyword_auto from mig.shared.fileio import unpickle, walk from mig.shared.notification import send_email -from 
mig.shared.safeeval import subprocess_popen, subprocess_pipe +from mig.shared.safeeval import subprocess_pipe, subprocess_popen from mig.shared.serial import dump -from mig.shared.useradm import init_user_adm, search_users, default_search +from mig.shared.useradm import default_search, init_user_adm, search_users -valid_output_formats = ['csv', 'txt', 'pickle', 'json', 'yaml'] +valid_output_formats = ["csv", "txt", "pickle", "json", "yaml"] -def usage(name='usagestats.py'): +def usage(name="usagestats.py"): """Usage help""" print("""Collect site stats based on MiG user database and file system. @@ -64,46 +63,57 @@ def usage(name='usagestats.py'): -t FS_TYPE Limit disk stats to mounts of given FS_TYPE -v Verbose output -q Quiet mode e.g. for cron use -""" % {'name': name}) +""" % {"name": name}) def compact_stats(configuration, stats, sep): """Helper to flatten stats for use in txt and csv output""" fill = {} - fill['sep'] = sep - fill['disk_use'] = '\n'.join([sep.join(i) for i in stats['disk']['use']]) - fill['disk_mounts'] = '\n'.join([sep.join(i) - for i in stats['disk']['mounts']]) - fill['totals_all_users'] = stats['totals']['all_users'] - fill['totals_active_users'] = stats['totals']['active_users'] - fill['totals_vgrids'] = stats['totals']['vgrids'] - fill['totals_archives'] = stats['totals']['archives'] - fill['weekly_register_users'] = stats['weekly']['register_users'] - fill['weekly_expire_users'] = stats['weekly']['expire_users'] - fill['weekly_vgrids'] = stats['weekly']['vgrids'] - fill['weekly_archives'] = stats['weekly']['archives'] - - fill['users_by_org'] = '' - org_list = list(stats['org_counts']['all_users']) + fill["sep"] = sep + fill["disk_use"] = "\n".join([sep.join(i) for i in stats["disk"]["use"]]) + fill["disk_mounts"] = "\n".join( + [sep.join(i) for i in stats["disk"]["mounts"]] + ) + fill["totals_all_users"] = stats["totals"]["all_users"] + fill["totals_active_users"] = stats["totals"]["active_users"] + fill["totals_vgrids"] = stats["totals"]["vgrids"] + fill["totals_archives"] = stats["totals"]["archives"] + fill["weekly_register_users"] = stats["weekly"]["register_users"] + fill["weekly_expire_users"] = stats["weekly"]["expire_users"] + fill["weekly_vgrids"] = stats["weekly"]["vgrids"] + fill["weekly_archives"] = stats["weekly"]["archives"] + + fill["users_by_org"] = "" + org_list = list(stats["org_counts"]["all_users"]) org_list.sort() for org in org_list: - total_cnt = stats['org_counts']['all_users'][org] - active_cnt = stats['org_counts']['active_users'].get(org, 0) - fill['users_by_org'] += '%d%s%d%s%s\n' % ( - total_cnt, sep, active_cnt, sep, org) - - fill['users_by_domain'] = '' - domain_list = list(stats['domain_counts']['all_users']) + total_cnt = stats["org_counts"]["all_users"][org] + active_cnt = stats["org_counts"]["active_users"].get(org, 0) + fill["users_by_org"] += "%d%s%d%s%s\n" % ( + total_cnt, + sep, + active_cnt, + sep, + org, + ) + + fill["users_by_domain"] = "" + domain_list = list(stats["domain_counts"]["all_users"]) domain_list.sort() for domain in domain_list: - total_cnt = stats['domain_counts']['all_users'][domain] - active_cnt = stats['domain_counts']['active_users'].get(domain, 0) - fill['users_by_domain'] += '%d%s%d%s%s\n' % ( - total_cnt, sep, active_cnt, sep, domain) + total_cnt = stats["domain_counts"]["all_users"][domain] + active_cnt = stats["domain_counts"]["active_users"].get(domain, 0) + fill["users_by_domain"] += "%d%s%d%s%s\n" % ( + total_cnt, + sep, + active_cnt, + sep, + domain, + ) return fill -def 
format_txt(configuration, stats, sep='\t'): +def format_txt(configuration, stats, sep="\t"): """Format stats for plain text output""" fill = compact_stats(configuration, stats, sep) txt = """=== Disk Use === @@ -154,7 +164,7 @@ def format_txt(configuration, stats, sep='\t'): return txt % fill -def format_csv(configuration, stats, sep=';'): +def format_csv(configuration, stats, sep=";"): """Format stats for plain text output""" fill = compact_stats(configuration, stats, sep) # TODO: improve csv format @@ -194,23 +204,23 @@ def write_sitestats(configuration, stats, path_prefix, output_format): for ext in output_format: dst_path = "%s.%s" % (path_prefix, ext) - if ext == 'csv': + if ext == "csv": out = format_csv(configuration, stats) with open(dst_path, "w") as fh: fh.write(out) - elif ext == 'txt': + elif ext == "txt": out = format_txt(configuration, stats) with open(dst_path, "w") as fh: fh.write(out) - elif ext in ['json', 'yaml', 'pickle']: + elif ext in ["json", "yaml", "pickle"]: dump(stats, dst_path, serializer=ext) else: return False return True -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +if "__main__" == __name__: + args, app_dir, db_path = init_user_adm() conf_path = None only_fs_types = [] expire = None @@ -221,77 +231,88 @@ def write_sitestats(configuration, stats, path_prefix, output_format): output_formats = [] search_filter = default_search() expire_before, expire_after = None, None - opt_args = 'a:b:c:d:fho:qs:t:u:v' + opt_args = "a:b:c:d:fho:qs:t:u:v" try: - (opts, args) = getopt.getopt(args, opt_args) + opts, args = getopt.getopt(args, opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-a': - search_filter['expire_after'] = int(val) - elif opt == '-b': - search_filter['expire_before'] = int(val) - elif opt == '-c': + for opt, val in opts: + if opt == "-a": + search_filter["expire_after"] = int(val) + elif opt == "-b": + search_filter["expire_before"] = int(val) + elif opt == "-c": conf_path = val - elif opt == '-d': + elif opt == "-d": db_path = val - elif opt == '-f': + elif opt == "-f": force = True - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-o': + elif opt == "-o": for ext in val.split(): if ext in valid_output_formats: output_formats.append(ext) else: print("Error: unsupported output format: %s" % ext) - elif opt == '-q': + elif opt == "-q": quiet = True verbose = False - elif opt == '-s': + elif opt == "-s": sitestats_home = val - elif opt == '-t': + elif opt == "-t": only_fs_types += val.split() - elif opt == '-v': + elif opt == "-v": verbose = True quiet = False else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" 
% opt) sys.exit(1) if conf_path and not os.path.isfile(conf_path): - print('Failed to read configuration file: %s' % conf_path) + print("Failed to read configuration file: %s" % conf_path) sys.exit(1) - (configuration, all_hits) = search_users( - default_search(), conf_path, db_path) + configuration, all_hits = search_users(default_search(), conf_path, db_path) logger = configuration.logger cmd_env = os.environ now = time.time() - site_stats = {'created': now, 'disk': {'use': [], 'mounts': []}, - 'totals': {'all_users': 0, 'active_users': 0, 'vgrids': 0, - 'archives': 0}, - 'weekly': {'all_users': 0, 'active_users': 0, 'vgrids': 0, - 'archives': 0}, - 'org_counts': {'all_users': {}, 'active_users': {}}, - 'domain_counts': {'all_users': {}, 'active_users': {}} - } + site_stats = { + "created": now, + "disk": {"use": [], "mounts": []}, + "totals": { + "all_users": 0, + "active_users": 0, + "vgrids": 0, + "archives": 0, + }, + "weekly": { + "all_users": 0, + "active_users": 0, + "vgrids": 0, + "archives": 0, + }, + "org_counts": {"all_users": {}, "active_users": {}}, + "domain_counts": {"all_users": {}, "active_users": {}}, + } if sitestats_home == keyword_auto: sitestats_home = configuration.sitestats_home sitestats_path = None if sitestats_home: - sitestats_path = os.path.join(sitestats_home, 'usagestats-%d' % now) + sitestats_path = os.path.join(sitestats_home, "usagestats-%d" % now) if not output_formats: - output_formats = ['json'] + output_formats = ["json"] if not quiet: - print("Writing collected site stats in %s.{%s}" % - (sitestats_path, ','.join(output_formats))) + print( + "Writing collected site stats in %s.{%s}" + % (sitestats_path, ",".join(output_formats)) + ) if not verbose and sitestats_path is None: print("Neither verbose nor writing site stats - boring!") @@ -299,29 +320,31 @@ def write_sitestats(configuration, stats, path_prefix, output_format): df_opts = [] # NOTE: df expects multiple file system types as individual options for fs_type in only_fs_types: - df_opts += ['-t', fs_type] + df_opts += ["-t", fs_type] # NOTE: we want utf8-encoded output as text str for concat below - proc = subprocess_popen(['/bin/df'] + df_opts, stdout=subprocess_pipe, - text=True, env=cmd_env) + proc = subprocess_popen( + ["/bin/df"] + df_opts, stdout=subprocess_pipe, text=True, env=cmd_env + ) proc.wait() for line in proc.stdout.readlines(): - site_stats['disk']['use'].append(line.strip().split()) + site_stats["disk"]["use"].append(line.strip().split()) if verbose: print("=== Disk Use ===") - print('\n'.join(['\t'.join(i) for i in site_stats['disk']['use']])) + print("\n".join(["\t".join(i) for i in site_stats["disk"]["use"]])) # NOTE: mount expects multiple file system types as single comma-sep arg mount_opts = [] - mount_opts += ['-t', ','.join(only_fs_types)] + mount_opts += ["-t", ",".join(only_fs_types)] # NOTE: we want utf8-encoded output as text str for concat below - proc = subprocess_popen(['mount'] + mount_opts, stdout=subprocess_pipe, - text=True) + proc = subprocess_popen( + ["mount"] + mount_opts, stdout=subprocess_pipe, text=True + ) proc.wait() for line in proc.stdout.readlines(): - site_stats['disk']['mounts'].append(line.strip().split()) + site_stats["disk"]["mounts"].append(line.strip().split()) if verbose: print("=== Disk Mounts ===") - print('\n'.join(['\t'.join(i) for i in site_stats['disk']['mounts']])) + print("\n".join(["\t".join(i) for i in site_stats["disk"]["mounts"]])) print("""Where * vgrid_files_home is all vgrid shared folders * 
vgrid_private_base/vgrid_public_base are all vgrid web portals @@ -332,14 +355,14 @@ def write_sitestats(configuration, stats, path_prefix, output_format): all_uids = [uid for (uid, user_dict) in all_hits] # all_uids.sort() # print("DEBUG: %s" % all_uids) - site_stats['totals']['all_users'] = len(all_uids) + site_stats["totals"]["all_users"] = len(all_uids) if verbose: print("== Totals ==") print("=== Registered Local Users ===") - print(site_stats['totals']['all_users']) + print(site_stats["totals"]["all_users"]) - search_filter['expire_after'] = now - (_, active_hits) = search_users(search_filter, conf_path, db_path) + search_filter["expire_after"] = now + _, active_hits = search_users(search_filter, conf_path, db_path) # only_fields = ['distinguished_name'] # for (uid, user_dict) in active_hits: # if only_fields: @@ -348,32 +371,32 @@ def write_sitestats(configuration, stats, path_prefix, output_format): # print(uid) active_uids = [uid for (uid, user_dict) in active_hits] # active_uids.sort() - site_stats['totals']['active_users'] = len(active_uids) + site_stats["totals"]["active_users"] = len(active_uids) # print("DEBUG: %s" % active_uids) if verbose: print("=== Active Local Users ===") - print(site_stats['totals']['active_users']) + print(site_stats["totals"]["active_users"]) # Extract dirs recursively in root of vgrid_home - for (root, dirs, files) in walk(configuration.vgrid_home): + for root, dirs, files in walk(configuration.vgrid_home): # Filter dot dirs - for i in [j for j in dirs if j.startswith('.')]: + for i in [j for j in dirs if j.startswith(".")]: dirs.remove(i) if not dirs: continue # print("DEBUG: %s %s" % (root, dirs)) - site_stats['totals']['vgrids'] += len(dirs) + site_stats["totals"]["vgrids"] += len(dirs) if verbose: print("=== Registered VGrids ===") - print(site_stats['totals']['vgrids']) + print(site_stats["totals"]["vgrids"]) # Archives are in root and in user ID subdirs archive_count = 0 - for (root, dirs, files) in walk(configuration.freeze_home): + for root, dirs, files in walk(configuration.freeze_home): # Filter dot dirs - for i in [j for j in dirs if j.startswith('.')]: + for i in [j for j in dirs if j.startswith(".")]: dirs.remove(i) - sub_dir = root.replace(configuration.freeze_home, '').strip(os.sep) + sub_dir = root.replace(configuration.freeze_home, "").strip(os.sep) sub_parts = sub_dir.split(os.sep) if len(sub_parts) > 2: # Stop recursion @@ -381,17 +404,19 @@ def write_sitestats(configuration, stats, path_prefix, output_format): for i in dirs: dirs.remove(i) continue - if sub_parts[-1].find('archive-') != -1 and \ - freeze_meta_filename in files: + if ( + sub_parts[-1].find("archive-") != -1 + and freeze_meta_filename in files + ): # print("DEBUG: %s" % root) - site_stats['totals']['archives'] += 1 + site_stats["totals"]["archives"] += 1 # Stop recursion for i in dirs: dirs.remove(i) if verbose: print("=== Frozen Archives ===") - print(site_stats['totals']['archives']) + print(site_stats["totals"]["archives"]) # TODO: this is inaccurate as it does not apply for e.g. short term peers. # We can eventually switch to the new created and renewed user fields. @@ -399,34 +424,34 @@ def write_sitestats(configuration, stats, path_prefix, output_format): # We simply lookup all users with expire more than 358 days from now. 
nearly_a_year = now + (365 - 7) * 24 * 3600 search_filter = default_search() - search_filter['expire_after'] = nearly_a_year - (_, reg_hits) = search_users(search_filter, conf_path, db_path) + search_filter["expire_after"] = nearly_a_year + _, reg_hits = search_users(search_filter, conf_path, db_path) reg_uids = [uid for (uid, user_dict) in reg_hits] # reg_uids.sort() - site_stats['weekly']['register_users'] = len(reg_uids) + site_stats["weekly"]["register_users"] = len(reg_uids) if verbose: print("== This Week ==") print("=== Registered and Renewed Local Users ===") - print(site_stats['weekly']['register_users']) + print(site_stats["weekly"]["register_users"]) a_week_ago = now - 7 * 24 * 3600 search_filter = default_search() - search_filter['expire_after'] = a_week_ago - search_filter['expire_before'] = now - (_, exp_hits) = search_users(search_filter, conf_path, db_path) + search_filter["expire_after"] = a_week_ago + search_filter["expire_before"] = now + _, exp_hits = search_users(search_filter, conf_path, db_path) exp_uids = [uid for (uid, user_dict) in exp_hits] # exp_uids.sort() - site_stats['weekly']['expire_users'] = len(exp_uids) + site_stats["weekly"]["expire_users"] = len(exp_uids) if verbose: print("=== Recently expired Local Users ===") - print(site_stats['weekly']['expire_users']) + print(site_stats["weekly"]["expire_users"]) # NOTE: no maxdepth since nested vgrids are allowed, mindepth is known for target, however # NOTE: vgrid_home/X ctime also gets updated on any file changes in that dir - for (root, dirs, files) in walk(configuration.vgrid_home): + for root, dirs, files in walk(configuration.vgrid_home): # Filter dot dirs - for i in [j for j in dirs if j.startswith('.')]: + for i in [j for j in dirs if j.startswith(".")]: dirs.remove(i) if root == configuration.vgrid_home: continue @@ -434,92 +459,96 @@ def write_sitestats(configuration, stats, path_prefix, output_format): if root_mtime < a_week_ago: continue # print("DEBUG: %s" % root) - site_stats['weekly']['vgrids'] += 1 + site_stats["weekly"]["vgrids"] += 1 if verbose: print("=== Registered and Updated VGrids ===") - print(site_stats['weekly']['vgrids']) + print(site_stats["weekly"]["vgrids"]) # NOTE: meta.pck file never changes for archives # TODO: update to fit only new client_id location when migrated - for (root, dirs, files) in walk(configuration.freeze_home): + for root, dirs, files in walk(configuration.freeze_home): # Filter dot dirs - for i in [j for j in dirs if j.startswith('.')]: + for i in [j for j in dirs if j.startswith(".")]: dirs.remove(i) - sub_dir = root.replace(configuration.freeze_home, '').strip(os.sep) + sub_dir = root.replace(configuration.freeze_home, "").strip(os.sep) sub_parts = sub_dir.split(os.sep) if len(sub_parts) > 3: # Stop recursion for i in dirs: dirs.remove(i) continue - if sub_parts[-1].find('archive-') == -1 or \ - not freeze_meta_filename in files: + if ( + sub_parts[-1].find("archive-") == -1 + or not freeze_meta_filename in files + ): continue meta_path = os.path.join(root, freeze_meta_filename) meta_mtime = os.path.getmtime(meta_path) if meta_mtime > a_week_ago and meta_mtime < now: - site_stats['weekly']['archives'] += 1 + site_stats["weekly"]["archives"] += 1 # print("DEBUG: %s" % root) if verbose: print("=== Frozen Archives ===") - print(site_stats['weekly']['archives']) + print(site_stats["weekly"]["archives"]) # Organization and email domain stats # All users org_map = {} domain_map = {} - for (uid, user_dict) in all_hits: - org = user_dict.get('organization', 'UNKNOWN') 
+ for uid, user_dict in all_hits: + org = user_dict.get("organization", "UNKNOWN") if org not in org_map: org_map[org] = 0 org_map[org] += 1 - email = user_dict.get('email', 'UNKNOWN') - domain = email.split('@', 1)[1].strip() + email = user_dict.get("email", "UNKNOWN") + domain = email.split("@", 1)[1].strip() if domain not in domain_map: domain_map[domain] = 0 domain_map[domain] += 1 - site_stats['org_counts']['all_users'].update(org_map) - site_stats['domain_counts']['all_users'].update(domain_map) + site_stats["org_counts"]["all_users"].update(org_map) + site_stats["domain_counts"]["all_users"].update(domain_map) # Active users org_map = {} domain_map = {} - for (uid, user_dict) in active_hits: - org = user_dict.get('organization', 'UNKNOWN') + for uid, user_dict in active_hits: + org = user_dict.get("organization", "UNKNOWN") if org not in org_map: org_map[org] = 0 org_map[org] += 1 - email = user_dict.get('email', 'UNKNOWN') - domain = email.split('@', 1)[1].strip() + email = user_dict.get("email", "UNKNOWN") + domain = email.split("@", 1)[1].strip() if domain not in domain_map: domain_map[domain] = 0 domain_map[domain] += 1 - site_stats['org_counts']['active_users'].update(org_map) - site_stats['domain_counts']['active_users'].update(domain_map) + site_stats["org_counts"]["active_users"].update(org_map) + site_stats["domain_counts"]["active_users"].update(domain_map) if verbose: print("== User Distribution ==") print("=== By Organization ===") - org_list = list(site_stats['org_counts']['all_users']) + org_list = list(site_stats["org_counts"]["all_users"]) org_list.sort() for org in org_list: - total_cnt = site_stats['org_counts']['all_users'][org] - active_cnt = site_stats['org_counts']['active_users'].get(org, 0) - print('%d\t%d\t%s' % (total_cnt, active_cnt, org)) + total_cnt = site_stats["org_counts"]["all_users"][org] + active_cnt = site_stats["org_counts"]["active_users"].get(org, 0) + print("%d\t%d\t%s" % (total_cnt, active_cnt, org)) print("=== By Email Domain ===") - domain_list = list(site_stats['domain_counts']['all_users']) + domain_list = list(site_stats["domain_counts"]["all_users"]) domain_list.sort() for domain in domain_list: - total_cnt = site_stats['domain_counts']['all_users'][domain] - active_cnt = site_stats['domain_counts']['active_users'].get( - domain, 0) - print('%d\t%d\t%s' % (total_cnt, active_cnt, domain)) - - if sitestats_path and not write_sitestats(configuration, site_stats, - sitestats_path, output_formats): + total_cnt = site_stats["domain_counts"]["all_users"][domain] + active_cnt = site_stats["domain_counts"]["active_users"].get( + domain, 0 + ) + print("%d\t%d\t%s" % (total_cnt, active_cnt, domain)) + + if sitestats_path and not write_sitestats( + configuration, site_stats, sitestats_path, output_formats + ): print("Error: writing site stats to %s failed!" 
              % sitestats_path)
    sys.exit(0)
diff --git a/mig/server/verifyarchives.py b/mig/server/verifyarchives.py
index c0c95bce0..f185219e7 100755
--- a/mig/server/verifyarchives.py
+++ b/mig/server/verifyarchives.py
@@ -27,8 +27,7 @@
 """Verify Archive integrity by comparing archive cache with actual contents"""

-from __future__ import print_function
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function

 import fnmatch
 import getopt
@@ -37,10 +36,17 @@
 import time

 from mig.shared.base import client_dir_id, distinguished_name_to_user
-from mig.shared.defaults import freeze_meta_filename, freeze_lock_filename, \
-    public_archive_index, public_archive_files, public_archive_doi, \
-    keyword_pending, keyword_final, keyword_any
-from mig.shared.freezefunctions import sorted_hash_algos, checksum_file
+from mig.shared.defaults import (
+    freeze_lock_filename,
+    freeze_meta_filename,
+    keyword_any,
+    keyword_final,
+    keyword_pending,
+    public_archive_doi,
+    public_archive_files,
+    public_archive_index,
+)
+from mig.shared.freezefunctions import checksum_file, sorted_hash_algos
 from mig.shared.serial import load
@@ -50,11 +56,16 @@ def fuzzy_match(i, j, offset=2.0):
     Useful for comparing e.g. file timestamps with minor fluctuations
     due to I/O times and rounding.
     """
-    return (i - offset < j and j < i + offset)
+    return i - offset < j and j < i + offset


-def check_archive_integrity(configuration, user_id, freeze_path,
-                            required_state=keyword_any, verbose=False):
+def check_archive_integrity(
+    configuration,
+    user_id,
+    freeze_path,
+    required_state=keyword_any,
+    verbose=False,
+):
     """Inspect Archives in freeze_path and compare contents to pickled cache.
     The cache is a list with one dictionary per file using the format:
     {'sha512sum': '...', 'name': 'relpath/to/file.ext',
@@ -69,15 +80,22 @@ def check_archive_integrity(configuration, user_id, freeze_path,
         print("Compare cache and contents for %s" % freeze_path)
     cache_path = "%s.cache" % freeze_path
     meta_path = os.path.join(freeze_path, freeze_meta_filename)
-    ignore_files = [freeze_lock_filename, freeze_meta_filename, '%s.lock' %
-                    freeze_meta_filename, public_archive_index,
-                    public_archive_files, public_archive_doi]
+    ignore_files = [
+        freeze_lock_filename,
+        freeze_meta_filename,
+        "%s.lock" % freeze_meta_filename,
+        public_archive_index,
+        public_archive_files,
+        public_archive_doi,
+    ]
     # NOTE: if archive has no actual files it has no cache file either
     if not os.path.exists(cache_path):
         archive_list = os.listdir(freeze_path)
         if [i for i in archive_list if not i in ignore_files]:
-            print("Archive %s has data content but no file cache in %s" %
-                  (freeze_path, cache_path))
+            print(
+                "Archive %s has data content but no file cache in %s"
+                % (freeze_path, cache_path)
+            )
             return False
         else:
             return True
@@ -86,69 +104,92 @@
     try:
         cache = load(cache_path)
         meta = load(meta_path)
     except Exception as exc:
-        print("Could not open archive helpers %s and %s for verification: %s" %
-              (cache_path, meta_path, exc))
+        print(
+            "Could not open archive helpers %s and %s for verification: %s"
+            % (cache_path, meta_path, exc)
+        )
         return False
-    meta_state = meta.get('STATE', keyword_pending)
+    meta_state = meta.get("STATE", keyword_pending)
     if required_state != keyword_any and meta_state != required_state:
-        print("Archive in %s is in %r state but check demanded state %r" %
-              (freeze_path, meta_state, required_state))
+        print(
+            "Archive in %s is in %r state but check demanded state %r"
+            % (freeze_path, meta_state, required_state)
+        )
         return False
     for entry in cache:
-        if entry['name'] in ignore_files:
+        if entry["name"] in ignore_files:
             continue
-        archive_path = os.path.join(freeze_path, entry['name'])
+        archive_path = os.path.join(freeze_path, entry["name"])
         try:
             archived_stat = os.stat(archive_path)
             archived_size = archived_stat.st_size
             archived_created = archived_stat.st_ctime
             archived_modified = archived_stat.st_mtime
-            if archived_size != entry['size']:
+            if archived_size != entry["size"]:
                 if meta_state == keyword_final:
-                    print("Archive entry %s has wrong size %d (expected %d)" %
-                          (archive_path, archived_size, entry['size']))
+                    print(
+                        "Archive entry %s has wrong size %d (expected %d)"
+                        % (archive_path, archived_size, entry["size"])
+                    )
                     return False
                 elif verbose:
-                    print("ignore size mismatch on non-final %s" %
-                          archive_path)
+                    print("ignore size mismatch on non-final %s" % archive_path)
             # NOTE: we allow a minor time offset to accept various fs hiccups
-            elif not fuzzy_match(entry['timestamp'], archived_created) and \
-                    not fuzzy_match(entry['timestamp'], archived_modified) and \
-                    not fuzzy_match(entry.get('file_mtime', -1), archived_modified):
+            elif (
+                not fuzzy_match(entry["timestamp"], archived_created)
+                and not fuzzy_match(entry["timestamp"], archived_modified)
+                and not fuzzy_match(
+                    entry.get("file_mtime", -1), archived_modified
+                )
+            ):
                 if meta_state == keyword_final:
-                    print("Archive entry %s has wrong timestamp %f / %f (expected %f, %s)" %
-                          (archive_path, archived_created, archived_modified,
-                           entry['timestamp'], archived_stat))
+                    print(
+                        "Archive entry %s has wrong timestamp %f / %f (expected %f, %s)"
+                        % (
+                            archive_path,
+                            archived_created,
+                            archived_modified,
+                            entry["timestamp"],
+                            archived_stat,
+                        )
+                    )
                     chksum_verified = False
                     for algo in sorted_hash_algos:
-                        chksum = entry.get(algo, '')
-                        if not chksum or ' ' in chksum:
+                        chksum = entry.get(algo, "")
+                        if not chksum or " " in chksum:
                             continue
-                        print("Checking that %s of %r matches %r" %
-                              (algo, archive_path, chksum))
-                        verify_chksum = checksum_file(archive_path, algo,
-                                                      max_chunks=-1)
+                        print(
+                            "Checking that %s of %r matches %r"
+                            % (algo, archive_path, chksum)
+                        )
+                        verify_chksum = checksum_file(
+                            archive_path, algo, max_chunks=-1
+                        )
                         if verify_chksum == chksum:
                             chksum_verified = True
                             break
                     if chksum_verified:
-                        print("Verified that %s of %r matches %r" %
-                              (algo, archive_path, chksum))
+                        print(
+                            "Verified that %s of %r matches %r"
+                            % (algo, archive_path, chksum)
+                        )
                     else:
                         return False
                 elif verbose:
-                    print("ignore ctime mismatch on non-final %s" %
-                          archive_path)
+                    print(
+                        "ignore ctime mismatch on non-final %s" % archive_path
+                    )
         except Exception as exc:
-            print("Archive entry %s failed verification: %s" %
-                  (archive_path, exc))
+            print(
+                "Archive entry %s failed verification: %s" % (archive_path, exc)
+            )
            return False
        if verbose:
            print("Archive entry %s passed verification" % archive_path)
    return True


-def usage(name='verifyarchives.py'):
+def usage(name="verifyarchives.py"):
     """Usage help"""
     print("""Verify Archive integrity using cache and actual contents.
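The verification pass above deliberately avoids exact timestamp equality: a cached stat value is accepted when it falls inside a small symmetric window around the on-disk value, and only final-state archives fall back to full checksum verification on a mismatch. A minimal standalone sketch of that tolerance window, reusing the fuzzy_match logic from this file with hypothetical stat values (not real archive cache entries):

    # Same tolerance check as fuzzy_match in verifyarchives.py: true when j
    # lies within +/- offset seconds of i.
    def fuzzy_match(i, j, offset=2.0):
        return i - offset < j and j < i + offset

    # Hypothetical values for illustration only.
    cached_stamp = 1700000000.0   # 'timestamp' field from the archive cache
    on_disk_mtime = 1700000001.4  # st_mtime reported by os.stat()

    assert fuzzy_match(cached_stamp, on_disk_mtime)          # 1.4s drift passes
    assert not fuzzy_match(cached_stamp, on_disk_mtime + 5)  # 6.4s drift fails
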
@@ -163,109 +204,122 @@ def usage(name='verifyarchives.py'): -n ARCHIVE_NAME Filter to specific Archive name(s) (pattern) -s REQUIRED_STATE Fail if Archive is not in REQUIRED_STATE (default is ANY) -v Verbose output -""" % {'name': name}) +""" % {"name": name}) -if '__main__' == __name__: +if "__main__" == __name__: conf_path = None verbose = False - opt_args = 'A:B:c:hI:n:s:v' + opt_args = "A:B:c:hI:n:s:v" now = int(time.time()) created_after, created_before = 0, now - distinguished_name = '*' - archive_name = '*' + distinguished_name = "*" + archive_name = "*" required_state = keyword_any try: - (opts, args) = getopt.getopt(sys.argv[1:], opt_args) + opts, args = getopt.getopt(sys.argv[1:], opt_args) except getopt.GetoptError as err: - print('Error: ', err.msg) + print("Error: ", err.msg) usage() sys.exit(1) - for (opt, val) in opts: - if opt == '-A': + for opt, val in opts: + if opt == "-A": after = now - if val.startswith('+'): + if val.startswith("+"): after += int(val[1:]) - elif val.startswith('-'): + elif val.startswith("-"): after -= int(val[1:]) else: after = int(val) created_after = after - elif opt == '-B': + elif opt == "-B": before = now - if val.startswith('+'): + if val.startswith("+"): before += int(val[1:]) - elif val.startswith('-'): + elif val.startswith("-"): before -= int(val[1:]) else: before = int(val) created_before = before - elif opt == '-c': + elif opt == "-c": conf_path = val - elif opt == '-h': + elif opt == "-h": usage() sys.exit(0) - elif opt == '-I': + elif opt == "-I": distinguished_name = val - elif opt == '-n': + elif opt == "-n": archive_name = val - elif opt == '-s': + elif opt == "-s": required_state = val - elif opt == '-v': + elif opt == "-v": verbose = True else: - print('Error: %s not supported!' % opt) + print("Error: %s not supported!" % opt) usage() sys.exit(0) archive_hits = {} archive_fails = 0 from mig.shared.conf import get_configuration_object + configuration = get_configuration_object() - print("searching for Archives with creation stamp between %d and %d" % - (created_after, created_before)) + print( + "searching for Archives with creation stamp between %d and %d" + % (created_after, created_before) + ) for user_dir in os.listdir(configuration.freeze_home): base_path = os.path.join(configuration.freeze_home, user_dir) # Skip non-dirs and dirs not matching user IDs - if not os.path.isdir(base_path) or user_dir.find('+') == -1: + if not os.path.isdir(base_path) or user_dir.find("+") == -1: continue user_id = client_dir_id(user_dir) user_dict = distinguished_name_to_user(user_id) if not fnmatch.fnmatch(user_id, distinguished_name): if verbose: - print("skip Archives for %s not matching owner pattern %s" % - (user_id, distinguished_name)) + print( + "skip Archives for %s not matching owner pattern %s" + % (user_id, distinguished_name) + ) continue for freeze_name in os.listdir(base_path): # NOTE: tempfile increased random part from 6 to 8 chars in py3 - if not fnmatch.fnmatch(freeze_name, "archive-??????") and \ - not fnmatch.fnmatch(freeze_name, "archive-????????"): + if not fnmatch.fnmatch( + freeze_name, "archive-??????" 
+            ) and not fnmatch.fnmatch(freeze_name, "archive-????????"):
                 continue
             if not fnmatch.fnmatch(freeze_name, archive_name):
                 if verbose:
-                    print("filter Archive %s not matching name pattern %s" %
-                          (freeze_name, archive_name))
+                    print(
+                        "filter Archive %s not matching name pattern %s"
+                        % (freeze_name, archive_name)
+                    )
                 continue
             freeze_path = os.path.join(base_path, freeze_name)
             created_time = int(round(os.path.getctime(freeze_path)))
             if created_time < created_after or created_time > created_before:
                 if verbose:
-                    print("skip Archive %s outside creation window %d - %d" %
-                          (freeze_name, created_after, created_before))
+                    print(
+                        "skip Archive %s outside creation window %d - %d"
+                        % (freeze_name, created_after, created_before)
+                    )
                 continue
             elif verbose:
-                print("found %s for %s from %d to verify" %
-                      (freeze_name, user_id, created_time))
+                print(
+                    "found %s for %s from %d to verify"
+                    % (freeze_name, user_id, created_time)
+                )
             archive_hits[user_id] = archive_hits.get(user_id, [])
             archive_hits[user_id].append(freeze_path)

     print("Archive integrity checks:")
-    for (user_id, archive_list) in archive_hits.items():
+    for user_id, archive_list in archive_hits.items():
         for freeze_path in archive_list:
             verified = check_archive_integrity(
-                configuration, user_id, freeze_path, required_state, verbose)
+                configuration, user_id, freeze_path, required_state, verbose
+            )
             if verified:
                 print("%s [PASS]" % freeze_path)
             else:
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..d65ee25b8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+# Project wide settings including tooling.
+
+# Linting tools
+# TODO: fix CodeQL errors in the symlinked files and disable exclude+skip here
+
+[tool.black]
+line-length = 80
+# NOTE: the following regex matches must be kept in sync with isort skip
+extend-exclude = '''
+    tests/data/
+    | tests/fixture/
+'''
+exclude = '''
+    bin/checkconf.py
+    | bin/createresource.py
+    | bin/notifypassword.py
+    | sbin/grid_ftps.py
+    | sbin/grid_openid.py
+    | sbin/grid_sftp.py
+    | sbin/grid_webdavs.py
+'''
+
+[tool.isort]
+profile = "black"
+line_length = 80
+# NOTE: the following paths must be kept in sync with black extend-exclude
+skip = [
+    "bin/notifypassword.py",
+    "bin/checkconf.py",
+    "bin/createresource.py",
+    "sbin/grid_ftps.py",
+    "sbin/grid_openid.py",
+    "sbin/grid_sftp.py",
+    "sbin/grid_webdavs.py"
+]
+skip_glob = [".git/*", "tests/data/*", "tests/fixture/*"]
diff --git a/sbin/grid_accounting.py b/sbin/grid_accounting.py
index c0cf7a84d..7f2966872 100755
--- a/sbin/grid_accounting.py
+++ b/sbin/grid_accounting.py
@@ -37,20 +37,25 @@
 import traceback

 from mig.lib.accounting import update_accounting
-from mig.lib.daemon import check_run, check_stop, interruptible_sleep, \
-    register_run_handler, register_stop_handler, reset_run, stop_running
+from mig.lib.daemon import (
+    check_run,
+    check_stop,
+    interruptible_sleep,
+    register_run_handler,
+    register_stop_handler,
+    reset_run,
+    stop_running,
+)
 from mig.shared.conf import get_configuration_object
 from mig.shared.logger import daemon_logger, register_hangup_handler


 if __name__ == "__main__":
-    print(
-        """This is the MiG accounting daemon that collect storage accounting
+    print("""This is the MiG accounting daemon that collects storage accounting
 information for users and their associated vgrids, archives and peers.
Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -""" - ) +""") # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) @@ -60,9 +65,9 @@ # Use separate logger - logger = daemon_logger("accounting", - configuration.user_accounting_log, - log_level) + logger = daemon_logger( + "accounting", configuration.user_accounting_log, log_level + ) configuration.logger = logger # Check if accounting is enabled @@ -70,9 +75,7 @@ if not configuration.site_enable_accounting: msg = "Accounting support is disabled in configuration!" logger.error(msg) - print("%s ERROR: %s" - % (datetime.datetime.now(), msg), - file=sys.stderr) + print("%s ERROR: %s" % (datetime.datetime.now(), msg), file=sys.stderr) sys.exit(1) # Allow e.g. logrotate to force log re-open after rotates @@ -86,8 +89,10 @@ throttle_secs = float(configuration.accounting_update_interval) main_pid = os.getpid() - msg = "(%s) Starting accounting daemon with throttle: %d secs" \ - % (main_pid, throttle_secs) + msg = "(%s) Starting accounting daemon with throttle: %d secs" % ( + main_pid, + throttle_secs, + ) logger.info(msg) print("%s %s" % (datetime.datetime.now(), msg)) @@ -95,16 +100,20 @@ while not check_stop(): try: if throttle: - interruptible_sleep(configuration, throttle_secs, - (check_run, check_stop)) + interruptible_sleep( + configuration, throttle_secs, (check_run, check_stop) + ) reset_run() if check_stop(): break t1 = time.time() status = update_accounting(configuration, verbose=True) t2 = time.time() - msg = "(%s) Updated accounting in %d secs with status: %s" \ - % (os.getpid(), int(t2-t1), status) + msg = "(%s) Updated accounting in %d secs with status: %s" % ( + os.getpid(), + int(t2 - t1), + status, + ) logger.info(msg) print("%s %s" % (datetime.datetime.now(), msg)) throttle = True @@ -112,18 +121,19 @@ stop_running() # NOTE: we can't be sure if SIGINT was sent to only main process # so we make sure to propagate to monitor child - msg = "(%s) Interrupt requested - shutdown" \ - % os.getpid() + msg = "(%s) Interrupt requested - shutdown" % os.getpid() logger.info(msg) print("%s %s" % (datetime.datetime.now(), msg)) except Exception as exc: throttle = True - msg = "(%s) Caught unexpected exception:\n%s" \ - % (os.getpid(), traceback.format_exc()) + msg = "(%s) Caught unexpected exception:\n%s" % ( + os.getpid(), + traceback.format_exc(), + ) logger.error(msg) - print("%s ERROR: %s" - % (datetime.datetime.now(), msg), - file=sys.stderr) + print( + "%s ERROR: %s" % (datetime.datetime.now(), msg), file=sys.stderr + ) msg = "(%s) Accounting daemon shutting down" % main_pid logger.info(msg) diff --git a/sbin/grid_janitor.py b/sbin/grid_janitor.py index 0ce480053..b930fc19e 100755 --- a/sbin/grid_janitor.py +++ b/sbin/grid_janitor.py @@ -35,8 +35,15 @@ import sys import time -from mig.lib.daemon import check_run, check_stop, interruptible_sleep, \ - register_run_handler, register_stop_handler, reset_run, stop_running +from mig.lib.daemon import ( + check_run, + check_stop, + interruptible_sleep, + register_run_handler, + register_stop_handler, + reset_run, + stop_running, +) from mig.lib.janitor import handle_janitor_tasks from mig.shared.conf import get_configuration_object from mig.shared.logger import daemon_logger, register_hangup_handler @@ -45,7 +52,7 @@ SHORT_THROTTLE_SECS = 5.0 LONG_THROTTLE_SECS = 30.0 -(configuration, logger) = (None, None) +configuration, logger = (None, None) if __name__ == "__main__": @@ -58,8 
+65,7 @@ # Use separate logger - logger = daemon_logger("janitor", configuration.user_janitor_log, - log_level) + logger = daemon_logger("janitor", configuration.user_janitor_log, log_level) configuration.logger = logger # Allow e.g. logrotate to force log re-open after rotates @@ -77,14 +83,12 @@ print(err_msg) sys.exit(1) - print( - """This is the MiG janitor daemon which cleans up stale state data, + print("""This is the MiG janitor daemon which cleans up stale state data, updates internal caches and prunes pending requests. Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -""" - ) +""") main_pid = os.getpid() print("Starting janitor daemon - Ctrl-C to quit") @@ -98,15 +102,18 @@ now = time.time() if last_failed: # Throttle on general exception in main loop - interruptible_sleep(configuration, LONG_THROTTLE_SECS, - (check_run, check_stop)) + interruptible_sleep( + configuration, LONG_THROTTLE_SECS, (check_run, check_stop) + ) last_failed = False elif handle_janitor_tasks(configuration, now) <= 0: - interruptible_sleep(configuration, LONG_THROTTLE_SECS, - (check_run, check_stop)) + interruptible_sleep( + configuration, LONG_THROTTLE_SECS, (check_run, check_stop) + ) else: - interruptible_sleep(configuration, SHORT_THROTTLE_SECS, - (check_run, check_stop)) + interruptible_sleep( + configuration, SHORT_THROTTLE_SECS, (check_run, check_stop) + ) reset_run() except KeyboardInterrupt: stop_running() diff --git a/sbin/grid_quota.py b/sbin/grid_quota.py index f7cf13ab8..0deb007ae 100755 --- a/sbin/grid_quota.py +++ b/sbin/grid_quota.py @@ -29,28 +29,32 @@ from __future__ import absolute_import, print_function +import datetime import os import sys import time import traceback -import datetime -from mig.lib.daemon import check_run, check_stop, interruptible_sleep, \ - register_run_handler, register_stop_handler, reset_run, stop_running -from mig.lib.quota import update_quota, supported_quota_backends +from mig.lib.daemon import ( + check_run, + check_stop, + interruptible_sleep, + register_run_handler, + register_stop_handler, + reset_run, + stop_running, +) +from mig.lib.quota import supported_quota_backends, update_quota from mig.shared.conf import get_configuration_object from mig.shared.logger import daemon_logger, register_hangup_handler - if __name__ == "__main__": - print( - """This is the MiG quota daemon which collects storage quota + print("""This is the MiG quota daemon which collects storage quota information for users and vgrids. Set the MIG_CONF environment to the server configuration path unless it is available in mig/server/MiGserver.conf -""" - ) +""") # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) @@ -60,9 +64,7 @@ # Use separate logger - logger = daemon_logger("quota", - configuration.user_quota_log, - log_level) + logger = daemon_logger("quota", configuration.user_quota_log, log_level) configuration.logger = logger # Check if quota is enabled @@ -70,21 +72,18 @@ if not configuration.site_enable_quota: msg = "Quota support is disabled in configuration!" 
logger.error(msg) - print("%s ERROR: %s" - % (datetime.datetime.now(), msg), - file=sys.stderr) + print("%s ERROR: %s" % (datetime.datetime.now(), msg), file=sys.stderr) sys.exit(1) # Check quota backend if configuration.quota_backend not in supported_quota_backends: - msg = "Quota backend: %s not in supported backends: %s" \ - % (configuration.quota_backend, - ", ".join(supported_quota_backends)) + msg = "Quota backend: %s not in supported backends: %s" % ( + configuration.quota_backend, + ", ".join(supported_quota_backends), + ) logger.error(msg) - print("%s ERROR: %s" - % (datetime.datetime.now(), msg), - file=sys.stderr) + print("%s ERROR: %s" % (datetime.datetime.now(), msg), file=sys.stderr) sys.exit(1) # Allow e.g. logrotate to force log re-open after rotates @@ -98,8 +97,10 @@ throttle_secs = float(configuration.quota_update_interval) main_pid = os.getpid() - msg = "(%s) Starting quota daemon with throttle: %d secs" \ - % (main_pid, throttle_secs) + msg = "(%s) Starting quota daemon with throttle: %d secs" % ( + main_pid, + throttle_secs, + ) logger.info(msg) print("%s %s" % (datetime.datetime.now(), msg)) @@ -107,16 +108,20 @@ while not check_stop(): try: if throttle: - interruptible_sleep(configuration, throttle_secs, - (check_run, check_stop)) + interruptible_sleep( + configuration, throttle_secs, (check_run, check_stop) + ) reset_run() if check_stop(): break t1 = time.time() status = update_quota(configuration) t2 = time.time() - msg = "(%s) Updated quota in %d secs with status: %s" \ - % (os.getpid(), int(t2-t1), status) + msg = "(%s) Updated quota in %d secs with status: %s" % ( + os.getpid(), + int(t2 - t1), + status, + ) logger.info(msg) print("%s %s" % (datetime.datetime.now(), msg)) throttle = True @@ -124,18 +129,19 @@ stop_running() # NOTE: we can't be sure if SIGINT was sent to only main process # so we make sure to propagate to monitor child - msg = "(%s) Interrupt requested - shutdown" \ - % os.getpid() + msg = "(%s) Interrupt requested - shutdown" % os.getpid() logger.info(msg) print("%s %s" % (datetime.datetime.now(), msg)) except Exception as exc: throttle = True - msg = "(%s) Caught unexpected exception:\n%s" \ - % (os.getpid(), traceback.format_exc()) + msg = "(%s) Caught unexpected exception:\n%s" % ( + os.getpid(), + traceback.format_exc(), + ) logger.error(msg) - print("%s ERROR: %s" - % (datetime.datetime.now(), msg), - file=sys.stderr) + print( + "%s ERROR: %s" % (datetime.datetime.now(), msg), file=sys.stderr + ) msg = "(%s) Quota daemon shutting down" % main_pid logger.info(msg) diff --git a/tests/__init__.py b/tests/__init__.py index bcec2ab8a..d1e480f4d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,10 +1,14 @@ def _print_identity(): import os import sys - python_version_string = sys.version.split(' ')[0] - mig_env = os.environ.get('MIG_ENV', 'local') - print("running with MIG_ENV='%s' under Python %s" % - (mig_env, python_version_string)) + + python_version_string = sys.version.split(" ")[0] + mig_env = os.environ.get("MIG_ENV", "local") + print( + "running with MIG_ENV='%s' under Python %s" + % (mig_env, python_version_string) + ) print("") + _print_identity() diff --git a/tests/support/__init__.py b/tests/support/__init__.py index eb2c8019f..da75f1b7c 100644 --- a/tests/support/__init__.py +++ b/tests/support/__init__.py @@ -28,8 +28,6 @@ """Supporting functions for the unit test framework""" -from collections import defaultdict -from configparser import ConfigParser import difflib import errno import io @@ -40,34 +38,40 @@ import 
shutil import stat import sys +from collections import defaultdict +from configparser import ConfigParser from types import SimpleNamespace -from unittest import TestCase, main as testmain +from unittest import TestCase +from unittest import main as testmain +from tests.support._env import MIG_ENV, PY2 from tests.support.configsupp import FakeConfiguration from tests.support.fixturesupp import _PreparedFixture -from tests.support.suppconst import MIG_BASE, TEST_BASE, \ - TEST_DATA_DIR, TEST_OUTPUT_DIR, ENVHELP_OUTPUT_DIR +from tests.support.suppconst import ( + ENVHELP_OUTPUT_DIR, + MIG_BASE, + TEST_BASE, + TEST_DATA_DIR, + TEST_OUTPUT_DIR, +) from tests.support.usersupp import UserAssertMixin -from tests.support._env import MIG_ENV, PY2 - - # Provide access to a configuration file for the active environment. -if MIG_ENV in ('local', 'docker'): +if MIG_ENV in ("local", "docker"): # force local testconfig - _output_dir = os.path.join(MIG_BASE, 'envhelp/output') + _output_dir = os.path.join(MIG_BASE, "envhelp/output") _conf_dir_name = "testconfs-%s" % (MIG_ENV,) _conf_dir = os.path.join(_output_dir, _conf_dir_name) - _local_conf = os.path.join(_conf_dir, 'MiGserver.conf') - _config_file = os.getenv('MIG_CONF', None) + _local_conf = os.path.join(_conf_dir, "MiGserver.conf") + _config_file = os.getenv("MIG_CONF", None) if _config_file is None: - os.environ['MIG_CONF'] = _local_conf + os.environ["MIG_CONF"] = _local_conf # adjust the link through which confs are accessed to suit the environment - _conf_link = os.path.join(_output_dir, 'testconfs') + _conf_link = os.path.join(_output_dir, "testconfs") assert os.path.lexists(_conf_link) # it must already exist - os.remove(_conf_link) # blow it away + os.remove(_conf_link) # blow it away os.symlink(_conf_dir, _conf_link) # recreate it using the active MIG_BASE else: raise NotImplementedError() @@ -97,7 +101,6 @@ from tests.support.loggersupp import FakeLogger, FakeLoggerChecker from tests.support.serversupp import make_wrapped_server - # Basic global logging configuration for testing @@ -181,7 +184,7 @@ def tearDown(self): @classmethod def tearDownClass(cls): - if MIG_ENV == 'docker': + if MIG_ENV == "docker": # the permissions story wrt running inside docker containers is # such that we can end up with files from previous test runs left # around that might subsequently cause spurious permissions errors @@ -209,20 +212,24 @@ def _reset_logging(self, stream): @staticmethod def _make_configuration_instance(testcase, configuration_to_make): - if configuration_to_make == 'fakeconfig': + if configuration_to_make == "fakeconfig": return FakeConfiguration(logger=testcase.logger) - elif configuration_to_make == 'testconfig': + elif configuration_to_make == "testconfig": from mig.shared.conf import get_configuration_object - configuration = get_configuration_object(skip_log=True, - disable_auth_log=True) + + configuration = get_configuration_object( + skip_log=True, disable_auth_log=True + ) configuration.logger = testcase.logger return configuration else: raise AssertionError( - "MigTestCase: unknown configuration %r" % (configuration_to_make,)) + "MigTestCase: unknown configuration %r" + % (configuration_to_make,) + ) def _provide_configuration(self): - return 'unspecified' + return "unspecified" @property def configuration(self): @@ -233,14 +240,16 @@ def configuration(self): configuration_to_make = self._provide_configuration() - if configuration_to_make == 'unspecified': + if configuration_to_make == "unspecified": raise AssertionError( - 
"configuration access but testcase did not request it") + "configuration access but testcase did not request it" + ) configuration_instance = self._make_configuration_instance( - self, configuration_to_make) + self, configuration_to_make + ) - if configuration_to_make == 'testconfig': + if configuration_to_make == "testconfig": # use the paths defined by the loaded configuration to create # the directories which are expected to be present by the code os.mkdir(configuration_instance.certs_path) @@ -274,7 +283,8 @@ def assertDirEmpty(self, relative_path): """Make sure the supplied path is an empty directory""" path_kind = self.assertPathExists(relative_path) assert path_kind == "dir", "expected a directory but found %s" % ( - path_kind, ) + path_kind, + ) absolute_path = os.path.join(TEST_OUTPUT_DIR, relative_path) entries = os.listdir(absolute_path) assert not entries, "directory is not empty" @@ -283,7 +293,8 @@ def assertDirNotEmpty(self, relative_path): """Make sure the supplied path is a non-empty directory""" path_kind = self.assertPathExists(relative_path) assert path_kind == "dir", "expected a directory but found %s" % ( - path_kind, ) + path_kind, + ) absolute_path = os.path.join(TEST_OUTPUT_DIR, relative_path) entries = os.listdir(absolute_path) assert entries, "directory is empty" @@ -291,28 +302,35 @@ def assertDirNotEmpty(self, relative_path): def assertFileContentIdentical(self, file_actual, file_expected): """Make sure file_actual and file_expected are identical""" - with io.open(file_actual) as f_actual, io.open(file_expected) as f_expected: + with io.open(file_actual) as f_actual, io.open( + file_expected + ) as f_expected: lhs = f_actual.readlines() rhs = f_expected.readlines() different_lines = list(difflib.unified_diff(rhs, lhs)) try: self.assertEqual(len(different_lines), 0) except AssertionError: - raise AssertionError("""differences found between files + raise AssertionError( + """differences found between files * %s * %s included: %s - """ % ( - os.path.relpath(file_expected, MIG_BASE), - os.path.relpath(file_actual, MIG_BASE), - ''.join(different_lines))) + """ + % ( + os.path.relpath(file_expected, MIG_BASE), + os.path.relpath(file_actual, MIG_BASE), + "".join(different_lines), + ) + ) def assertFileExists(self, relative_path): """Make sure relative_path exists and is a file""" path_kind = self.assertPathExists(relative_path) assert path_kind == "file", "expected a file but found %s" % ( - path_kind, ) + path_kind, + ) return os.path.join(TEST_OUTPUT_DIR, relative_path) def assertPathExists(self, relative_path): @@ -342,13 +360,14 @@ def assertPathWithin(self, path, start=None): """Make sure path is within start directory""" if not is_path_within(path, start=start): raise AssertionError( - "path %s is not within directory %s" % (path, start)) + "path %s is not within directory %s" % (path, start) + ) @staticmethod def pretty_display_path(absolute_path): assert os.path.isabs(absolute_path) relative_path = os.path.relpath(absolute_path, start=MIG_BASE) - assert not relative_path.startswith('..') + assert not relative_path.startswith("..") return relative_path @staticmethod @@ -359,7 +378,9 @@ def _provision_test_user(testcase, distinguished_name): Note that this method, along with a number of others, are defined in the user portion of the test support libraries. 
""" - return UserAssertMixin._provision_test_user(testcase, distinguished_name) + return UserAssertMixin._provision_test_user( + testcase, distinguished_name + ) def is_path_within(path, start=None, _msg=None): @@ -369,7 +390,7 @@ def is_path_within(path, start=None, _msg=None): relative = os.path.relpath(path, start=start) except: return False - return not relative.startswith('..') + return not relative.startswith("..") def ensure_dirs_exist(absolute_dir): @@ -400,7 +421,7 @@ def temppath(relative_path, test_case, ensure_dir=False): # failsafe path checking that supplied paths are rooted within valid paths is_tmp_path_within_safe_dir = False - for start in (ENVHELP_OUTPUT_DIR): + for start in ENVHELP_OUTPUT_DIR: is_tmp_path_within_safe_dir = is_path_within(tmp_path, start=start) if is_tmp_path_within_safe_dir: break @@ -413,7 +434,8 @@ def temppath(relative_path, test_case, ensure_dir=False): except OSError as oserr: if oserr.errno == errno.EEXIST: raise AssertionError( - "ABORT: use of unclean output path: %s" % tmp_path) + "ABORT: use of unclean output path: %s" % tmp_path + ) return tmp_path diff --git a/tests/support/_env.py b/tests/support/_env.py index 2c71386a4..d86bbd2d3 100644 --- a/tests/support/_env.py +++ b/tests/support/_env.py @@ -2,10 +2,10 @@ import sys # expose the configured environment as a constant -MIG_ENV = os.environ.get('MIG_ENV', 'local') +MIG_ENV = os.environ.get("MIG_ENV", "local") # force the chosen environment globally -os.environ['MIG_ENV'] = MIG_ENV +os.environ["MIG_ENV"] = MIG_ENV # expose a boolean indicating whether we are executing on Python 2 -PY2 = (sys.version_info[0] == 2) +PY2 = sys.version_info[0] == 2 diff --git a/tests/support/assertover.py b/tests/support/assertover.py index 52b445ab2..3e31b7b5b 100644 --- a/tests/support/assertover.py +++ b/tests/support/assertover.py @@ -29,11 +29,13 @@ class NoBlockError(AssertionError): """Decorate AssertionError for our own convenience""" + pass class NoCasesError(AssertionError): """Decorate AssertionError for our own convenience""" + pass @@ -76,10 +78,15 @@ def __exit__(self, exc_type, exc_value, traceback): if not any(self._attempts): return True - value_lines = ["- <%r> : %s" % (attempt[0], str(attempt[1])) for - attempt in self._attempts if attempt] - raise AssertionError("assertions raised for the following values:\n%s" - % '\n'.join(value_lines)) + value_lines = [ + "- <%r> : %s" % (attempt[0], str(attempt[1])) + for attempt in self._attempts + if attempt + ] + raise AssertionError( + "assertions raised for the following values:\n%s" + % "\n".join(value_lines) + ) def record_attempt(self, attempt_info): """Record the result of a test attempt""" @@ -89,7 +96,9 @@ def to_check_callable(self): def raise_unless_consulted(): if not self._consulted: raise AssertionError( - "no examiniation made of assertion of multiple values") + "no examiniation made of assertion of multiple values" + ) + return raise_unless_consulted def assert_success(self): @@ -103,4 +112,7 @@ def _execute_block(cls, block, block_value): block.__call__(block_value) return None except Exception as blockexc: - return (block_value, blockexc,) + return ( + block_value, + blockexc, + ) diff --git a/tests/support/configsupp.py b/tests/support/configsupp.py index 0846e465d..bb8511542 100644 --- a/tests/support/configsupp.py +++ b/tests/support/configsupp.py @@ -27,20 +27,21 @@ """Configuration related details within the test support library.""" -from tests.support.loggersupp import FakeLogger - from mig.shared.compat import SimpleNamespace 
diff --git a/tests/support/configsupp.py b/tests/support/configsupp.py
index 0846e465d..bb8511542 100644
--- a/tests/support/configsupp.py
+++ b/tests/support/configsupp.py
@@ -27,20 +27,21 @@

 """Configuration related details within the test support library."""

-from tests.support.loggersupp import FakeLogger
-
 from mig.shared.compat import SimpleNamespace
-from mig.shared.configuration import \
-    _CONFIGURATION_ARGUMENTS, _CONFIGURATION_PROPERTIES
+from mig.shared.configuration import (
+    _CONFIGURATION_ARGUMENTS,
+    _CONFIGURATION_PROPERTIES,
+)
+from tests.support.loggersupp import FakeLogger


 def _ensure_only_configuration_keys(thedict):
-    """Check a dictionary contains only keys valid as Configuration properties.
-    """
+    """Check a dictionary contains only keys valid as Configuration properties."""
     unknown_keys = set(thedict.keys()) - set(_CONFIGURATION_ARGUMENTS)
-    assert len(unknown_keys) == 0, \
-        "non-Configuration keys: %s" % (', '.join(unknown_keys),)
+    assert len(unknown_keys) == 0, "non-Configuration keys: %s" % (
+        ", ".join(unknown_keys),
+    )


 def _generate_namespace_kwargs():
@@ -49,7 +50,7 @@ def _generate_namespace_kwargs():
     """
     properties_and_defaults = dict(_CONFIGURATION_PROPERTIES)
-    properties_and_defaults['logger'] = None
+    properties_and_defaults["logger"] = None
     return properties_and_defaults
diff --git a/tests/support/fixturesupp.py b/tests/support/fixturesupp.py
index c418fc7a2..0172a3e43 100644
--- a/tests/support/fixturesupp.py
+++ b/tests/support/fixturesupp.py
@@ -27,12 +27,12 @@

 """Fixture related details within the test support library."""

-from configparser import ConfigParser
-from datetime import date, timedelta
 import json
 import os
 import pickle
 import shutil
+from configparser import ConfigParser
+from datetime import date, timedelta
 from time import mktime
 from types import SimpleNamespace

@@ -51,30 +51,35 @@ def _fixturefile_loadrelative(fixture_name, fixture_format=None):
     assert fixture_format is not None, "fixture format must be specified"
     relative_path_with_ext = "%s.%s" % (fixture_name, fixture_format)
     tmp_path = os.path.join(TEST_FIXTURE_DIR, relative_path_with_ext)
-    assert os.path.isfile(tmp_path), \
-        'fixture named "%s" with format %s is not present: %s' % \
-        (fixture_name, fixture_format, relative_path_with_ext)
+    assert os.path.isfile(
+        tmp_path
+    ), 'fixture named "%s" with format %s is not present: %s' % (
+        fixture_name,
+        fixture_format,
+        relative_path_with_ext,
+    )

     data = None
-    if fixture_format == 'binary':
-        with open(tmp_path, 'rb') as binfile:
+    if fixture_format == "binary":
+        with open(tmp_path, "rb") as binfile:
             data = binfile.read()
-    elif fixture_format == 'json':
+    elif fixture_format == "json":
         with open(tmp_path) as jsonfile:
             data = json.load(jsonfile, object_hook=_FixtureHint.object_hook)
             _hints_apply_from_instances_if_present(data)
             _hints_apply_from_fixture_ini_if_present(fixture_name, data)
     else:
         raise AssertionError(
-            "unsupported fixture format: %s" % (fixture_format,))
+            "unsupported fixture format: %s" % (fixture_format,)
+        )

     return data, tmp_path


-def _fixturefile_normname(relative_path, prefix=''):
+def _fixturefile_normname(relative_path, prefix=""):
     """Grab normname from relative_path and optionally add a path prefix"""
-    normname, _ = relative_path.split('--')
+    normname, _ = relative_path.split("--")
     if prefix:
         return os.path.join(prefix, normname)
     return normname
@@ -96,6 +101,7 @@ def _fixturefile_normname(relative_path, prefix=''):
 #
 #

+
 def _hints_apply_array_of_tuples(value, modifier):
     """
     Convert list of lists such that its values are instead tuples.
@@ -109,7 +115,7 @@ def _hints_apply_today_relative(value, modifier):
     Generate a time value by applying a declared delta to today's date.
     """
-    kind, delta = modifier.split('|')
+    kind, delta = modifier.split("|")
     if kind == "days":
         time_delta = timedelta(days=int(delta))
     adjusted_datetime = date.today() + time_delta
@@ -131,15 +137,17 @@ def _hints_apply_dict_bytes_to_strings_kv(input_dict, modifier):
     for k, v in input_dict.items():
         key_to_use = k
         if isinstance(k, bytes):
-            key_to_use = str(k, 'utf8')
+            key_to_use = str(k, "utf8")

         if isinstance(v, dict):
-            output_dict[key_to_use] = _hints_apply_dict_bytes_to_strings_kv(v, modifier)
+            output_dict[key_to_use] = _hints_apply_dict_bytes_to_strings_kv(
+                v, modifier
+            )
             continue

         val_to_use = v
         if isinstance(v, bytes):
-            val_to_use = str(v, 'utf8')
+            val_to_use = str(v, "utf8")

         output_dict[key_to_use] = val_to_use
@@ -159,15 +167,17 @@ def _hints_apply_dict_strings_to_bytes_kv(input_dict, modifier):
     for k, v in input_dict.items():
         key_to_use = k
         if isinstance(k, str):
-            key_to_use = bytes(k, 'utf8')
+            key_to_use = bytes(k, "utf8")

         if isinstance(v, dict):
-            output_dict[key_to_use] = _hints_apply_dict_strings_to_bytes_kv(v, modifier)
+            output_dict[key_to_use] = _hints_apply_dict_strings_to_bytes_kv(
+                v, modifier
+            )
             continue

         val_to_use = v
         if isinstance(v, str):
-            val_to_use = bytes(v, 'utf8')
+            val_to_use = bytes(v, "utf8")

         output_dict[key_to_use] = val_to_use
@@ -176,21 +186,21 @@ def _hints_apply_dict_strings_to_bytes_kv(input_dict, modifier):

 # hints that can be applied without an additional modifier argument
 _HINTS_APPLIERS_ARGLESS = {
-    'array_of_tuples': _hints_apply_array_of_tuples,
-    'today_relative': _hints_apply_today_relative,
-    'convert_dict_bytes_to_strings_kv': _hints_apply_dict_bytes_to_strings_kv,
-    'convert_dict_strings_to_bytes_kv': _hints_apply_dict_strings_to_bytes_kv,
+    "array_of_tuples": _hints_apply_array_of_tuples,
+    "today_relative": _hints_apply_today_relative,
+    "convert_dict_bytes_to_strings_kv": _hints_apply_dict_bytes_to_strings_kv,
+    "convert_dict_strings_to_bytes_kv": _hints_apply_dict_strings_to_bytes_kv,
 }

 # hints applicable to the conversion of attributes during fixture loading
 _FIXTUREFILE_APPLIERS_ATTRIBUTES = {
-    'array_of_tuples': _hints_apply_array_of_tuples,
-    'today_relative': _hints_apply_today_relative,
+    "array_of_tuples": _hints_apply_array_of_tuples,
+    "today_relative": _hints_apply_today_relative,
 }

 # hints applied when writing the contents of a fixture as a temporary file
 _FIXTUREFILE_APPLIERS_ONWRITE = {
-    'convert_dict_strings_to_bytes_kv': _hints_apply_dict_strings_to_bytes_kv,
+    "convert_dict_strings_to_bytes_kv": _hints_apply_dict_strings_to_bytes_kv,
 }

@@ -222,7 +232,7 @@ def _load_hints_ini_for_fixture_if_present(fixture_name):
         pass

     # ensure empty required fixture to avoid extra conditionals later
-    for required_section in ['ATTRIBUTES']:
+    for required_section in ["ATTRIBUTES"]:
         if not hints.has_section(required_section):
             hints.add_section(required_section)

@@ -239,10 +249,10 @@ def _hints_apply_from_fixture_ini_if_present(fixture_name, json_object):

     # apply any attributes hints ahead of specified conversions such that any
     # key can be specified matching what is visible within the loaded fixture
-    for item_name, item_hint_unparsed in hints['ATTRIBUTES'].items():
+    for item_name, item_hint_unparsed in hints["ATTRIBUTES"].items():
         loaded_value = json_object[item_name]

-        item_hint_and_maybe_modifier = item_hint_unparsed.split('--')
+        item_hint_and_maybe_modifier = item_hint_unparsed.split("--")
         item_hint = item_hint_and_maybe_modifier[0]
         if len(item_hint_and_maybe_modifier) == 2:
             modifier = item_hint_and_maybe_modifier[1]
@@
-267,7 +277,9 @@ def __init__(self, hint=None, modifier=None, value=None): def decode_hint(hint_obj): """Produce a value based on the properties of a hint instance.""" assert isinstance(hint_obj, _FixtureHint) - value_from_loaded_value = _FIXTUREFILE_APPLIERS_ATTRIBUTES[hint_obj.hint] + value_from_loaded_value = _FIXTUREFILE_APPLIERS_ATTRIBUTES[ + hint_obj.hint + ] return value_from_loaded_value(hint_obj.value, hint_obj.modifier) @staticmethod @@ -278,11 +290,14 @@ def object_hook(decoded_object): """ if "_FixtureHint" in decoded_object: - fixture_hint = _FixtureHint(decoded_object["hint"], decoded_object["modifier"]) + fixture_hint = _FixtureHint( + decoded_object["hint"], decoded_object["modifier"] + ) return _FixtureHint.decode_hint(fixture_hint) return decoded_object + # @@ -295,7 +310,7 @@ def fixturepath(relative_path): def _to_display_path(value): """Convert an absolute path to one to be shown as part of test output.""" display_path = os.path.relpath(value, MIG_BASE) - if not display_path.startswith('.'): + if not display_path.startswith("."): return "./" + display_path return display_path @@ -307,10 +322,9 @@ class _PreparedFixture: NO_DATA = object() - def __init__(self, testcase, - fixture_name, - fixture_format='', - fixture_data=NO_DATA): + def __init__( + self, testcase, fixture_name, fixture_format="", fixture_data=NO_DATA + ): self.testcase = testcase self.fixture_name = fixture_name self.fixture_format = fixture_format @@ -336,9 +350,12 @@ def assertAgainstFixture(self, value): if self.fixture_format: message_infix = " with format %s" % (self.fixture_format,) else: - message_infix = '' + message_infix = "" message = "value differed from fixture named %s%s\n\n%s" % ( - self.fixture_name, message_infix, raised_exception) + self.fixture_name, + message_infix, + raised_exception, + ) raise AssertionError(message) def write_to_dir(self, target_dir, output_format=None): @@ -347,42 +364,47 @@ def write_to_dir(self, target_dir, output_format=None): directory applying any onwrite hints that may be specified. 
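
The object_hook machinery above turns any JSON object carrying a "_FixtureHint" key into a computed value before the fixture data reaches the test. A self-contained sketch of just the today_relative branch, re-implemented inline for illustration; the real code routes through _FixtureHint.decode_hint and the appliers tables, and the fixture snippet plus the simplified date return value here are assumptions, not taken from the diff:

    import json
    from datetime import date, timedelta

    # hypothetical fixture entry: "expire" becomes a date 7 days from today
    raw = '{"expire": {"_FixtureHint": true, "hint": "today_relative", "modifier": "days|7"}}'

    def object_hook(decoded_object):
        # mirror of the hint convention: "days|<n>" split on the pipe
        if "_FixtureHint" in decoded_object:
            kind, delta = decoded_object["modifier"].split("|")
            assert kind == "days"
            return date.today() + timedelta(days=int(delta))
        return decoded_object

    print(json.loads(raw, object_hook=object_hook)["expire"])

Because json.loads applies the hook innermost-first, the hint object is already replaced by its computed value by the time the enclosing fixture dictionary is assembled.
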
""" - assert self.fixture_data is not self.NO_DATA, \ - "fixture is not populated with data" + assert ( + self.fixture_data is not self.NO_DATA + ), "fixture is not populated with data" assert os.path.isabs(target_dir) # convert fixture name (which includes the varaint) to the target file - fixture_file_target = _fixturefile_normname(self.fixture_name, prefix=target_dir) + fixture_file_target = _fixturefile_normname( + self.fixture_name, prefix=target_dir + ) output_data = self.fixture_data # now apply any onwrite conversions hints = _load_hints_ini_for_fixture_if_present(self.fixture_name) - for item_name in hints['ONWRITE']: + for item_name in hints["ONWRITE"]: if item_name not in _FIXTUREFILE_APPLIERS_ONWRITE: raise AssertionError( - "unsupported fixture conversion: %s" % (item_name,)) + "unsupported fixture conversion: %s" % (item_name,) + ) - enabled = hints.getboolean('ONWRITE', item_name) + enabled = hints.getboolean("ONWRITE", item_name) if not enabled: continue hint_fn = _FIXTUREFILE_APPLIERS_ONWRITE[item_name] output_data = hint_fn(output_data, None) - if output_format == 'binary': - with open(fixture_file_target, 'wb') as fixture_outputfile: + if output_format == "binary": + with open(fixture_file_target, "wb") as fixture_outputfile: fixture_outputfile.write(output_data) - elif output_format == 'json': - with open(fixture_file_target, 'w') as fixture_outputfile: + elif output_format == "json": + with open(fixture_file_target, "w") as fixture_outputfile: json.dump(output_data, fixture_outputfile) - elif output_format == 'pickle': - with open(fixture_file_target, 'wb') as fixture_outputfile: + elif output_format == "pickle": + with open(fixture_file_target, "wb") as fixture_outputfile: pickle.dump(output_data, fixture_outputfile) else: raise AssertionError( - "unsupported fixture format: %s" % (output_format,)) + "unsupported fixture format: %s" % (output_format,) + ) @staticmethod def from_relpath(testcase, fixture_name, fixture_format): @@ -392,11 +414,16 @@ def from_relpath(testcase, fixture_name, fixture_format): """ fixture_data, fixture_path = _fixturefile_loadrelative( - fixture_name, fixture_format) - return _PreparedFixture(testcase, fixture_name, fixture_format, fixture_data) + fixture_name, fixture_format + ) + return _PreparedFixture( + testcase, fixture_name, fixture_format, fixture_data + ) class FixtureAssertMixin: def prepareFixtureAssert(self, fixture_relpath, fixture_format=None): """Prepare to assert a value against a fixture.""" - return _PreparedFixture.from_relpath(self, fixture_relpath, fixture_format) + return _PreparedFixture.from_relpath( + self, fixture_relpath, fixture_format + ) diff --git a/tests/support/loggersupp.py b/tests/support/loggersupp.py index b1eb2c295..5de13fb8a 100644 --- a/tests/support/loggersupp.py +++ b/tests/support/loggersupp.py @@ -28,9 +28,9 @@ """Logger related details within the test support library.""" -from collections import defaultdict import os import re +from collections import defaultdict from tests.support.suppconst import MIG_BASE, TEST_BASE @@ -47,7 +47,8 @@ class FakeLogger: """ RE_UNCLOSEDFILE = re.compile( - 'unclosed file <.*? name=\'(?P.*?)\'( .*?)?>') + "unclosed file <.*? 
name='(?P.*?)'( .*?)?>" + ) def __init__(self): self.channels_dict = defaultdict(list) @@ -70,13 +71,19 @@ def check_empty_and_reset(self): # complain loudly (and in detail) in the case of unclosed files if len(unclosed_by_file) > 0: - messages = '\n'.join({' --> %s: line=%s, file=%s' % (fname, lineno, outname) - for fname, (lineno, outname) in unclosed_by_file.items()}) - raise RuntimeError('unclosed files encountered:\n%s' % (messages,)) - - if channels_dict['error'] and not forgive_by_channel['error']: - raise RuntimeError('errors reported to logger:\n%s' % - '\n'.join(channels_dict['error'])) + messages = "\n".join( + { + " --> %s: line=%s, file=%s" % (fname, lineno, outname) + for fname, (lineno, outname) in unclosed_by_file.items() + } + ) + raise RuntimeError("unclosed files encountered:\n%s" % (messages,)) + + if channels_dict["error"] and not forgive_by_channel["error"]: + raise RuntimeError( + "errors reported to logger:\n%s" + % "\n".join(channels_dict["error"]) + ) def forgive_errors(self): """Allow log errors for cases where they are expected""" @@ -90,26 +97,26 @@ def forgive_messages_on(self, *, channel_name=None): def debug(self, line): """Mock log action of same name""" - self._append_as('debug', line) + self._append_as("debug", line) def error(self, line): """Mock log action of same name""" - self._append_as('error', line) + self._append_as("error", line) def info(self, line): """Mock log action of same name""" - self._append_as('info', line) + self._append_as("info", line) def warning(self, line): """Mock log action of same name""" - self._append_as('warning', line) + self._append_as("warning", line) def write(self, message): """Actual write handler""" - channel, namespace, specifics = message.split(':', 2) + channel, namespace, specifics = message.split(":", 2) # ignore everything except warnings sent by the python runtime - if not (channel == 'WARNING' and namespace == 'py.warnings'): + if not (channel == "WARNING" and namespace == "py.warnings"): return filename_and_datatuple = FakeLogger.identify_unclosed_file(specifics) @@ -119,10 +126,10 @@ def write(self, message): @staticmethod def identify_unclosed_file(specifics): """Warn about unclosed files""" - filename, lineno, exc_name, message = specifics.split(':', 3) + filename, lineno, exc_name, message = specifics.split(":", 3) exc_name = exc_name.lstrip() - if exc_name != 'ResourceWarning': + if exc_name != "ResourceWarning": return matched = FakeLogger.RE_UNCLOSEDFILE.match(message.lstrip()) @@ -131,7 +138,8 @@ def identify_unclosed_file(specifics): relative_testfile = os.path.relpath(filename, start=MIG_BASE) relative_outputfile = os.path.relpath( - matched.groups('location')[0], start=TEST_BASE) + matched.groups("location")[0], start=TEST_BASE + ) return (relative_testfile, (lineno, relative_outputfile)) diff --git a/tests/support/picklesupp.py b/tests/support/picklesupp.py index 667dd4b01..262e4c50c 100644 --- a/tests/support/picklesupp.py +++ b/tests/support/picklesupp.py @@ -29,8 +29,8 @@ import pickle -from tests.support.suppconst import TEST_OUTPUT_DIR from tests.support.fixturesupp import _HINTS_APPLIERS_ARGLESS +from tests.support.suppconst import TEST_OUTPUT_DIR class PickleAssertMixin: @@ -44,7 +44,7 @@ def assertPickledFile(self, pickle_file_path, apply_hints=None): having been optionally transformed as requested by hints. 
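
The RE_UNCLOSEDFILE pattern in the loggersupp.py hunk above deserves a second look: the named group (?P<location>...) had been mangled in transit and is restored here. The group name matters because identify_unclosed_file() later extracts it via matched.groups("location")[0]. A quick standalone check of the restored pattern; the sample warning message is illustrative, not taken from the diff:

    import re

    RE_UNCLOSEDFILE = re.compile(
        "unclosed file <.*? name='(?P<location>.*?)'( .*?)?>"
    )

    # shaped like the py.warnings ResourceWarning payload the fake logger sees
    message = "unclosed file <_io.TextIOWrapper name='/tmp/example.txt' mode='r'>"
    matched = RE_UNCLOSEDFILE.match(message)
    assert matched is not None
    assert matched.group("location") == "/tmp/example.txt"
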
""" - with open(pickle_file_path, 'rb') as picklefile: + with open(pickle_file_path, "rb") as picklefile: pickled = pickle.load(picklefile) if not apply_hints: diff --git a/tests/support/serversupp.py b/tests/support/serversupp.py index 0e0fd4b94..bdca0fa33 100644 --- a/tests/support/serversupp.py +++ b/tests/support/serversupp.py @@ -27,7 +27,8 @@ """Server threading related details within the test support library""" -from threading import Thread, Event as ThreadEvent +from threading import Event as ThreadEvent +from threading import Thread class ServerWithinThreadExecutor: @@ -50,7 +51,7 @@ def run(self): """Mimic the same method from the standard thread API""" server_args, server_kwargs = self._arguments - server_kwargs['on_start'] = lambda _: self._started.set() + server_kwargs["on_start"] = lambda _: self._started.set() self._wrapped = self._serverclass(*server_args, **server_kwargs) diff --git a/tests/support/snapshotsupp.py b/tests/support/snapshotsupp.py index 355095814..a24bb2238 100644 --- a/tests/support/snapshotsupp.py +++ b/tests/support/snapshotsupp.py @@ -28,14 +28,14 @@ import difflib import errno -import re import os +import re from tests.support.suppconst import TEST_BASE -HTML_TAG = '' -MARKER_CONTENT_BEGIN = '' -MARKER_CONTENT_END = '' +HTML_TAG = "" +MARKER_CONTENT_BEGIN = "" +MARKER_CONTENT_END = "" TEST_SNAPSHOTS_DIR = os.path.join(TEST_BASE, "snapshots") try: @@ -57,9 +57,9 @@ def _html_content_only(value): # set the index after the content marker content_start_index += len(MARKER_CONTENT_BEGIN) # we now need to remove the container div inside it ..first find it - content_start_inner_div = value.find('', content_start_inner_div) + 1 + content_start_index = value.find(">", content_start_inner_div) + 1 content_end_index = value.find(MARKER_CONTENT_END) assert content_end_index > -1, "unable to locate end of content" @@ -77,7 +77,7 @@ def _delimited_lines(value): lines = [] while from_index < last_index: - found_index = value.find('\n', from_index) + found_index = value.find("\n", from_index) if found_index == -1: break found_index += 1 @@ -93,8 +93,8 @@ def _delimited_lines(value): def _force_refresh_snapshots(): """Check whether the environment specifies snapshots should be refreshed.""" - env_refresh_snapshots = os.environ.get('REFRESH_SNAPSHOTS', 'no').lower() - return env_refresh_snapshots in ('true', 'yes', '1') + env_refresh_snapshots = os.environ.get("REFRESH_SNAPSHOTS", "no").lower() + return env_refresh_snapshots in ("true", "yes", "1") class SnapshotAssertMixin: @@ -107,7 +107,7 @@ def _snapshotsupp_compare_snapshot(self, extension, actual_content): In the case a snapshot does not exist it is saved on first invocation. 
""" - file_name = ''.join([self._testMethodName, ".", extension]) + file_name = "".join([self._testMethodName, ".", extension]) file_path = os.path.join(TEST_SNAPSHOTS_DIR, file_name) if not os.path.isfile(file_path) or _force_refresh_snapshots(): @@ -126,11 +126,12 @@ def _snapshotsupp_compare_snapshot(self, extension, actual_content): udiff = difflib.unified_diff( _delimited_lines(expected_content), _delimited_lines(actual_content), - 'expected', - 'actual' + "expected", + "actual", ) raise AssertionError( - "content did not match snapshot\n\n%s" % (''.join(udiff),)) + "content did not match snapshot\n\n%s" % ("".join(udiff),) + ) def assertSnapshot(self, actual_content, extension=None): """Load a snapshot corresponding to the named test and check that what @@ -148,4 +149,4 @@ def assertSnapshotOfHtmlContent(self, actual_content): """ actual_content = _html_content_only(actual_content) - self._snapshotsupp_compare_snapshot('html', actual_content) + self._snapshotsupp_compare_snapshot("html", actual_content) diff --git a/tests/support/suppconst.py b/tests/support/suppconst.py index 148303f0d..204b0ba29 100644 --- a/tests/support/suppconst.py +++ b/tests/support/suppconst.py @@ -29,11 +29,11 @@ from tests.support._env import MIG_ENV -if MIG_ENV == 'local': +if MIG_ENV == "local": # Use abspath for __file__ on Py2 _SUPPORT_DIR = os.path.dirname(os.path.abspath(__file__)) -elif MIG_ENV == 'docker': - _SUPPORT_DIR = '/usr/src/app/tests/support' +elif MIG_ENV == "docker": + _SUPPORT_DIR = "/usr/src/app/tests/support" else: raise NotImplementedError("ABORT: unsupported environment: %s" % (MIG_ENV,)) @@ -46,7 +46,8 @@ ENVHELP_OUTPUT_DIR = os.path.join(ENVHELP_DIR, "output") -if __name__ == '__main__': +if __name__ == "__main__": + def print_root_relative(prefix, path): print("%s = /%s" % (prefix, os.path.relpath(path, MIG_BASE))) diff --git a/tests/support/usersupp.py b/tests/support/usersupp.py index 65a02fa4a..849c98f92 100644 --- a/tests/support/usersupp.py +++ b/tests/support/usersupp.py @@ -33,15 +33,11 @@ import pickle from mig.shared.base import client_id_dir - from tests.support.fixturesupp import _PreparedFixture +TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" -TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' - -_FIXTURE_NAME_BY_USER_DN = { - TEST_USER_DN: 'MiG-users.db--example' -} +_FIXTURE_NAME_BY_USER_DN = {TEST_USER_DN: "MiG-users.db--example"} class UserAssertMixin: @@ -59,9 +55,9 @@ def _provision_user_db_dir(testcase): conf_user_db_home = testcase.configuration.user_db_home os.makedirs(conf_user_db_home, exist_ok=True) - user_db_file = os.path.join(conf_user_db_home, 'MiG-users.db') + user_db_file = os.path.join(conf_user_db_home, "MiG-users.db") if os.path.exists(user_db_file): - raise AssertionError('a user database file already exists') + raise AssertionError("a user database file already exists") return conf_user_db_home @@ -79,19 +75,19 @@ def _provision_test_user(testcase, distinguished_name): try: fixture_relpath = _FIXTURE_NAME_BY_USER_DN[distinguished_name] except KeyError: - raise AssertionError('supplied test user is not known as a fixture') + raise AssertionError("supplied test user is not known as a fixture") # note: this is a non-standard direct use of fixture preparation due # to this being bootstrap code and should not be used elsewhere prepared_fixture = _PreparedFixture.from_relpath( - testcase, - fixture_relpath, - fixture_format='json' + testcase, fixture_relpath, 
fixture_format="json" ) # write out the user database fixture containing the user - prepared_fixture.write_to_dir(conf_user_db_home, output_format='pickle') + prepared_fixture.write_to_dir(conf_user_db_home, output_format="pickle") - test_user_dir = UserAssertMixin._provision_test_user_dirs(testcase, distinguished_name) + test_user_dir = UserAssertMixin._provision_test_user_dirs( + testcase, distinguished_name + ) return test_user_dir @@ -112,12 +108,16 @@ def _provision_test_user_dirs(testcase, distinguished_name): # create the test user settings directory conf_user_settings = os.path.normpath(self.configuration.user_settings) - test_user_settings_dir = os.path.join(conf_user_settings, test_client_dir_name) + test_user_settings_dir = os.path.join( + conf_user_settings, test_client_dir_name + ) os.makedirs(test_user_settings_dir) # create an empty user settings file - test_user_settings_file = os.path.join(test_user_settings_dir, 'settings') - with open(test_user_settings_file, 'wb') as outfile: + test_user_settings_file = os.path.join( + test_user_settings_dir, "settings" + ) + with open(test_user_settings_file, "wb") as outfile: pickle.dump({}, outfile) return test_user_dir @@ -154,9 +154,11 @@ def _provision_test_users(testcase, *distinguished_names): # write out all the users we have assembled by populating an empty # fixture with their data but using a known fixture name and thus one # suitably hinted so a production format pickle file ends up on-disk - prepared_fixture = _PreparedFixture(testcase, 'MiG-users.db--example') + prepared_fixture = _PreparedFixture(testcase, "MiG-users.db--example") prepared_fixture.fixture_data = users_by_dn - prepared_fixture.write_to_dir(conf_user_db_home, output_format='pickle') + prepared_fixture.write_to_dir(conf_user_db_home, output_format="pickle") for distinguished_name in distinguished_names: - UserAssertMixin._provision_test_user_dirs(testcase, distinguished_name) + UserAssertMixin._provision_test_user_dirs( + testcase, distinguished_name + ) diff --git a/tests/support/wsgisupp.py b/tests/support/wsgisupp.py index 1105d0db8..2a6a209d2 100644 --- a/tests/support/wsgisupp.py +++ b/tests/support/wsgisupp.py @@ -27,22 +27,21 @@ """Test support library for WSGI.""" -from collections import namedtuple import codecs +from collections import namedtuple from io import BytesIO from urllib.parse import urlencode, urlparse from werkzeug.datastructures import MultiDict - # named type representing the tuple that is passed to WSGI handlers -_PreparedWsgi = namedtuple('_PreparedWsgi', ['environ', 'start_response']) +_PreparedWsgi = namedtuple("_PreparedWsgi", ["environ", "start_response"]) class FakeWsgiStartResponse: """Glue object that conforms to the same interface as the start_response() - in the WSGI specs but records the calls to it such that they can be - inspected and, for our purposes, asserted against.""" + in the WSGI specs but records the calls to it such that they can be + inspected and, for our purposes, asserted against.""" def __init__(self): self.calls = [] @@ -51,7 +50,9 @@ def __call__(self, status, headers, exc=None): self.calls.append((status, headers, exc)) -def create_wsgi_environ(configuration, wsgi_url, method='GET', query=None, headers=None, form=None): +def create_wsgi_environ( + configuration, wsgi_url, method="GET", query=None, headers=None, form=None +): """Populate the necessary variables that will constitute a valid WSGI environment given a URL to which we will make a requests under test and various other options that set up the 
nature of that request.""" @@ -59,21 +60,21 @@ def create_wsgi_environ(configuration, wsgi_url, method='GET', query=None, heade parsed_url = urlparse(wsgi_url) if query: - method = 'GET' + method = "GET" request_query = urlencode(query) wsgi_input = () elif form: - method = 'POST' - request_query = '' + method = "POST" + request_query = "" - body = urlencode(MultiDict(form)).encode('ascii') + body = urlencode(MultiDict(form)).encode("ascii") headers = headers or {} - if not 'Content-Type' in headers: - headers['Content-Type'] = 'application/x-www-form-urlencoded' + if not "Content-Type" in headers: + headers["Content-Type"] = "application/x-www-form-urlencoded" - headers['Content-Length'] = str(len(body)) + headers["Content-Length"] = str(len(body)) wsgi_input = BytesIO(body) else: request_query = parsed_url.query @@ -83,26 +84,27 @@ class _errors: """Internal helper to ignore wsgi.errors close method calls""" def close(self, *ars, **kwargs): - """"Simply ignore""" + """ "Simply ignore""" pass environ = {} - environ['wsgi.errors'] = _errors() - environ['wsgi.input'] = wsgi_input - environ['wsgi.url_scheme'] = parsed_url.scheme - environ['wsgi.version'] = (1, 0) - environ['MIG_CONF'] = configuration.config_file - environ['HTTP_HOST'] = parsed_url.netloc - environ['PATH_INFO'] = parsed_url.path - environ['QUERY_STRING'] = request_query - environ['REQUEST_METHOD'] = method - environ['SCRIPT_URI'] = ''.join( - ('http://', environ['HTTP_HOST'], environ['PATH_INFO'])) + environ["wsgi.errors"] = _errors() + environ["wsgi.input"] = wsgi_input + environ["wsgi.url_scheme"] = parsed_url.scheme + environ["wsgi.version"] = (1, 0) + environ["MIG_CONF"] = configuration.config_file + environ["HTTP_HOST"] = parsed_url.netloc + environ["PATH_INFO"] = parsed_url.path + environ["QUERY_STRING"] = request_query + environ["REQUEST_METHOD"] = method + environ["SCRIPT_URI"] = "".join( + ("http://", environ["HTTP_HOST"], environ["PATH_INFO"]) + ) if headers: for k, v in headers.items(): - header_key = k.replace('-', '_').upper() - if header_key.startswith('CONTENT'): + header_key = k.replace("-", "_").upper() + if header_key.startswith("CONTENT"): # Content-* headers must not be prefixed in WSGI pass else: @@ -119,15 +121,15 @@ def create_wsgi_start_response(): def prepare_wsgi(configuration, url, **kwargs): return _PreparedWsgi( create_wsgi_environ(configuration, url, **kwargs), - create_wsgi_start_response() + create_wsgi_start_response(), ) def _trigger_and_unpack_result(wsgi_result): chunks = list(wsgi_result) assert len(chunks) > 0, "invocation returned no output" - complete_value = b''.join(chunks) - decoded_value = codecs.decode(complete_value, 'utf8') + complete_value = b"".join(chunks) + decoded_value = codecs.decode(complete_value, "utf8") return decoded_value @@ -140,7 +142,7 @@ def assertWsgiResponse(self, wsgi_result, fake_wsgi, expected_status_code): content = _trigger_and_unpack_result(wsgi_result) def called_once(fake): - assert hasattr(fake, 'calls') + assert hasattr(fake, "calls") return len(fake.calls) == 1 fake_start_response = fake_wsgi.start_response diff --git a/tests/test_bin_verifyvgridformat.py b/tests/test_bin_verifyvgridformat.py index 3a581c112..25578abf0 100644 --- a/tests/test_bin_verifyvgridformat.py +++ b/tests/test_bin_verifyvgridformat.py @@ -34,6 +34,7 @@ from unittest.mock import patch from bin.verifyvgridformat import verify_vgrid_format + # Imports required for building the vgrid test fixtures from mig.shared.vgrid import vgrid_flat_name @@ -197,9 +198,7 @@ def 
test_specific_vgrid_name_modern_format_returns_true(self): """Verify returns True when targeting a named vgrid in modern format.""" self._make_vgrid("testvgrid") self._make_vgrid_files_symlink("testvgrid") - result = verify_vgrid_format( - self.configuration, vgrid_name="testvgrid" - ) + result = verify_vgrid_format(self.configuration, vgrid_name="testvgrid") self.assertTrue(result) def test_specific_vgrid_name_legacy_format_returns_true(self): @@ -219,9 +218,7 @@ def test_verbose_no_vgrids_produces_no_output(self): """Verify verbose mode with no vgrids produces no output and returns True.""" captured = io.StringIO() with patch("sys.stdout", captured): - result = verify_vgrid_format( - self.configuration, verbose=True - ) + result = verify_vgrid_format(self.configuration, verbose=True) self.assertTrue(result) self.assertEqual(captured.getvalue(), "") @@ -231,9 +228,7 @@ def test_verbose_modern_format_reports_up_to_date(self): self._make_vgrid_files_symlink("testvgrid") captured = io.StringIO() with patch("sys.stdout", captured): - result = verify_vgrid_format( - self.configuration, verbose=True - ) + result = verify_vgrid_format(self.configuration, verbose=True) self.assertTrue(result) self.assertIn("up to date", captured.getvalue()) @@ -243,9 +238,7 @@ def test_verbose_legacy_format_reports_reformat_commands(self): self._make_vgrid_files_dir("testvgrid") captured = io.StringIO() with patch("sys.stdout", captured): - result = verify_vgrid_format( - self.configuration, verbose=True - ) + result = verify_vgrid_format(self.configuration, verbose=True) self.assertTrue(result) output = captured.getvalue() self.assertIn("legacy format", output) diff --git a/tests/test_booleans.py b/tests/test_booleans.py index 5246197ee..a3564b8ff 100644 --- a/tests/test_booleans.py +++ b/tests/test_booleans.py @@ -2,6 +2,7 @@ from tests.support import MigTestCase, testmain + class TestBooleans(MigTestCase): def test_true(self): self.assertEqual(True, True) @@ -10,5 +11,5 @@ def test_false(self): self.assertEqual(False, False) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_install_generateconfs.py b/tests/test_mig_install_generateconfs.py index e8ed90241..c68318cd9 100644 --- a/tests/test_mig_install_generateconfs.py +++ b/tests/test_mig_install_generateconfs.py @@ -33,13 +33,13 @@ import os import sys -from tests.support import MIG_BASE, MigTestCase, testmain, cleanpath +from tests.support import MIG_BASE, MigTestCase, cleanpath, testmain def _import_generateconfs(): """Internal helper to work around non-package import location""" - sys.path.append(os.path.join(MIG_BASE, 'mig/install')) - mod = importlib.import_module('generateconfs') + sys.path.append(os.path.join(MIG_BASE, "mig/install")) + mod = importlib.import_module("generateconfs") sys.path.pop(-1) return mod @@ -51,12 +51,14 @@ def _import_generateconfs(): def create_fake_generate_confs(return_dict=None): """Fake generate confs helper""" + def _generate_confs(*args, **kwargs): _generate_confs.settings = kwargs if return_dict: return (return_dict, {}) else: return ({}, {}) + _generate_confs.settings = None return _generate_confs @@ -69,52 +71,64 @@ class MigInstallGenerateconfs__main(MigTestCase): """Unit test helper for the migrid code pointed to in class name""" def test_option_permanent_freeze(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + 
"confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--permanent_freeze', 'yes'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--permanent_freeze", "yes"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) self.assertEqual(exit_code, 0) def test_option_storage_protocols(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--storage_protocols', 'proto1 proto2 proto3'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--storage_protocols", "proto1 proto2 proto3"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) self.assertEqual(exit_code, 0) settings = fake_generate_confs.settings - self.assertIn('storage_protocols', settings) - self.assertEqual(settings['storage_protocols'], 'proto1 proto2 proto3') + self.assertIn("storage_protocols", settings) + self.assertEqual(settings["storage_protocols"], "proto1 proto2 proto3") def test_option_wwwserve_max_bytes(self): - expected_generated_dir = cleanpath('confs-stdlocal', self, - ensure_dir=True) - with open(os.path.join(expected_generated_dir, "instructions.txt"), - "w"): + expected_generated_dir = cleanpath( + "confs-stdlocal", self, ensure_dir=True + ) + with open( + os.path.join(expected_generated_dir, "instructions.txt"), "w" + ): pass fake_generate_confs = create_fake_generate_confs( - dict(destination_dir=expected_generated_dir)) - test_arguments = ['--wwwserve_max_bytes', '43211234'] + dict(destination_dir=expected_generated_dir) + ) + test_arguments = ["--wwwserve_max_bytes", "43211234"] exit_code = main( - test_arguments, _generate_confs=fake_generate_confs, _print=noop) + test_arguments, _generate_confs=fake_generate_confs, _print=noop + ) settings = fake_generate_confs.settings - self.assertIn('wwwserve_max_bytes', settings) - self.assertEqual(settings['wwwserve_max_bytes'], 43211234) + self.assertIn("wwwserve_max_bytes", settings) + self.assertEqual(settings["wwwserve_max_bytes"], 43211234) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_lib_accounting.py b/tests/test_mig_lib_accounting.py index fc0003fc4..9eb60a931 100644 --- a/tests/test_mig_lib_accounting.py +++ b/tests/test_mig_lib_accounting.py @@ -30,8 +30,11 @@ import os import pickle -from mig.lib.accounting import get_usage, human_readable_filesize, \ - update_accounting +from mig.lib.accounting import ( + get_usage, + human_readable_filesize, + update_accounting, +) from mig.shared.base import client_id_dir from mig.shared.defaults import peers_filename from tests.support import MigTestCase, ensure_dirs_exist @@ -39,71 +42,90 @@ TEST_MTIME = 1768925307 TEST_SOFTLIMIT_BYTES = 109951162777600 TEST_HARDLIMIT_BYTES = 109951162777600 -TEST_CLIENT_DN = '/C=DK/ST=NA/L=NA/O=Test 
Org/OU=NA/CN=Test User/emailAddress=test@user.com' +TEST_CLIENT_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@user.com" +) TEST_CLIENT_BYTES = 206128256 -TEST_EXT_DN = '/C=DK/ST=NA/L=NA/O=PEER Org/OU=NA/CN=Test Peer/emailAddress=peer@example.com' +TEST_EXT_DN = "/C=DK/ST=NA/L=NA/O=PEER Org/OU=NA/CN=Test Peer/emailAddress=peer@example.com" TEST_EXT_BYTES = 16806128256 TEST_FREEZE_BYTES = 128256 -TEST_VGRID_NAME1 = 'TestVgrid1' +TEST_VGRID_NAME1 = "TestVgrid1" TEST_VGRID_BYTES1 = 406128256 -TEST_VGRID_NAME2 = 'TestVgrid2' +TEST_VGRID_NAME2 = "TestVgrid2" TEST_VGRID_BYTES2 = 606128256 -TEST_VGRID_NAME3 = 'TestVgrid3' +TEST_VGRID_NAME3 = "TestVgrid3" TEST_VGRID_BYTES3 = 806128256 -TEST_VGRID_TOTAL_BYTES = TEST_VGRID_BYTES1 \ - + TEST_VGRID_BYTES2 \ - + TEST_VGRID_BYTES3 -TEST_TOTAL_BYTES = TEST_CLIENT_BYTES \ - + TEST_EXT_BYTES \ - + TEST_FREEZE_BYTES \ +TEST_VGRID_TOTAL_BYTES = ( + TEST_VGRID_BYTES1 + TEST_VGRID_BYTES2 + TEST_VGRID_BYTES3 +) +TEST_TOTAL_BYTES = ( + TEST_CLIENT_BYTES + + TEST_EXT_BYTES + + TEST_FREEZE_BYTES + TEST_VGRID_TOTAL_BYTES -TEST_LUSTRE_QUOTA_INFO = {'next_pid': 192, 'mtime': TEST_MTIME} -TEST_CLIENT_USAGE = {'lustre_pid': 42, - 'files': 11, - 'bytes': TEST_CLIENT_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE1 = {'lustre_pid': 43, - 'files': 111, - 'bytes': TEST_VGRID_BYTES1, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE2 = {'lustre_pid': 44, - 'files': 222, - 'bytes': TEST_VGRID_BYTES2, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_VGRID_USAGE3 = {'lustre_pid': 45, - 'files': 333, - 'bytes': TEST_VGRID_BYTES3, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_EXT_USAGE = {'lustre_pid': 46, - 'files': 1, - 'bytes': TEST_EXT_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_FREEZE_USAGE = {'lustre_pid': 47, - 'files': 1, - 'bytes': TEST_FREEZE_BYTES, - 'softlimit_bytes': TEST_SOFTLIMIT_BYTES, - 'hardlimit_bytes': TEST_HARDLIMIT_BYTES, - 'mtime': TEST_MTIME} -TEST_PEERS = {TEST_EXT_DN: {'kind': 'collaboration', - 'distinguished_name': TEST_EXT_DN, - 'country': 'DK', - 'label': 'TEST', - 'state': '', - 'expire': '2222-12-31', - 'full_name': 'Test Peer', - 'organization': 'PEER Org', - 'email': 'peer@example.com' - }} +) +TEST_LUSTRE_QUOTA_INFO = {"next_pid": 192, "mtime": TEST_MTIME} +TEST_CLIENT_USAGE = { + "lustre_pid": 42, + "files": 11, + "bytes": TEST_CLIENT_BYTES, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE1 = { + "lustre_pid": 43, + "files": 111, + "bytes": TEST_VGRID_BYTES1, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE2 = { + "lustre_pid": 44, + "files": 222, + "bytes": TEST_VGRID_BYTES2, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_VGRID_USAGE3 = { + "lustre_pid": 45, + "files": 333, + "bytes": TEST_VGRID_BYTES3, + "softlimit_bytes": TEST_SOFTLIMIT_BYTES, + "hardlimit_bytes": TEST_HARDLIMIT_BYTES, + "mtime": TEST_MTIME, +} +TEST_EXT_USAGE = { + "lustre_pid": 46, + "files": 1, + "bytes": TEST_EXT_BYTES, + "softlimit_bytes": 
TEST_SOFTLIMIT_BYTES,
+    "hardlimit_bytes": TEST_HARDLIMIT_BYTES,
+    "mtime": TEST_MTIME,
+}
+TEST_FREEZE_USAGE = {
+    "lustre_pid": 47,
+    "files": 1,
+    "bytes": TEST_FREEZE_BYTES,
+    "softlimit_bytes": TEST_SOFTLIMIT_BYTES,
+    "hardlimit_bytes": TEST_HARDLIMIT_BYTES,
+    "mtime": TEST_MTIME,
+}
+TEST_PEERS = {
+    TEST_EXT_DN: {
+        "kind": "collaboration",
+        "distinguished_name": TEST_EXT_DN,
+        "country": "DK",
+        "label": "TEST",
+        "state": "",
+        "expire": "2222-12-31",
+        "full_name": "Test Peer",
+        "organization": "PEER Org",
+        "email": "peer@example.com",
+    }
+}
 
 
 class MigLibAccounting(MigTestCase):
@@ -111,7 +133,7 @@ class MigLibAccounting(MigTestCase):
 
     def _provide_configuration(self):
         """Prepare isolated test config"""
-        return 'testconfig'
+        return "testconfig"
 
     def before_each(self):
         """Set up test configuration and reset state before each test"""
@@ -120,15 +142,17 @@ def before_each(self):
         self.configuration.site_enable_quota = True
         self.configuration.site_enable_accounting = True
-        self.configuration.quota_backend = 'lustre'
-
-        quota_basepath = os.path.join(self.configuration.quota_home,
-                                      self.configuration.quota_backend)
-        quota_user_path = os.path.join(quota_basepath, 'user')
-        quota_vgrid_path = os.path.join(quota_basepath, 'vgrid')
-        quota_freeze_path = os.path.join(quota_basepath, 'freeze')
-        test_client_peers_path = os.path.join(self.configuration.user_settings,
-                                              client_id_dir(TEST_CLIENT_DN))
+        self.configuration.quota_backend = "lustre"
+
+        quota_basepath = os.path.join(
+            self.configuration.quota_home, self.configuration.quota_backend
+        )
+        quota_user_path = os.path.join(quota_basepath, "user")
+        quota_vgrid_path = os.path.join(quota_basepath, "vgrid")
+        quota_freeze_path = os.path.join(quota_basepath, "freeze")
+        test_client_peers_path = os.path.join(
+            self.configuration.user_settings, client_id_dir(TEST_CLIENT_DN)
+        )
 
         ensure_dirs_exist(self.configuration.vgrid_home)
         ensure_dirs_exist(self.configuration.user_settings)
@@ -141,61 +165,69 @@ def before_each(self):
 
         # Ensure fake vgrid and write owner
-        for vgrid_name in [TEST_VGRID_NAME1,
-                           TEST_VGRID_NAME2,
-                           TEST_VGRID_NAME3]:
+        for vgrid_name in [
+            TEST_VGRID_NAME1,
+            TEST_VGRID_NAME2,
+            TEST_VGRID_NAME3,
+        ]:
             vgrid_home_path = os.path.join(
-                self.configuration.vgrid_home, vgrid_name)
+                self.configuration.vgrid_home, vgrid_name
+            )
             ensure_dirs_exist(vgrid_home_path)
-            vgrid_owners_filepath = os.path.join(vgrid_home_path, 'owners')
-            with open(vgrid_owners_filepath, 'wb') as fh:
+            vgrid_owners_filepath = os.path.join(vgrid_home_path, "owners")
+            with open(vgrid_owners_filepath, "wb") as fh:
                 fh.write(pickle.dumps([TEST_CLIENT_DN]))
 
         # Write fake quota
-        test_lustre_quota_info_filepath \
-            = os.path.join(self.configuration.quota_home,
-                           '%s.pck' % self.configuration.quota_backend)
-        with open(test_lustre_quota_info_filepath, 'wb') as fh:
+        test_lustre_quota_info_filepath = os.path.join(
+            self.configuration.quota_home,
+            "%s.pck" % self.configuration.quota_backend,
+        )
+        with open(test_lustre_quota_info_filepath, "wb") as fh:
             fh.write(pickle.dumps(TEST_LUSTRE_QUOTA_INFO))
 
-        quota_test_client_path \
-            = os.path.join(quota_user_path,
-                           "%s.pck" % client_id_dir(TEST_CLIENT_DN))
+        quota_test_client_path = os.path.join(
+            quota_user_path, "%s.pck" % client_id_dir(TEST_CLIENT_DN)
+        )
 
-        with open(quota_test_client_path, 'wb') as fh:
+        with open(quota_test_client_path, "wb") as fh:
             fh.write(pickle.dumps(TEST_CLIENT_USAGE))
 
-        quot_test_vgrid_filepath1 = os.path.join(quota_vgrid_path,
-                                                 "%s.pck" % TEST_VGRID_NAME1)
-        with open(quot_test_vgrid_filepath1, 'wb') as fh:
+        quot_test_vgrid_filepath1 = os.path.join(
+            quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME1
+        )
+        with open(quot_test_vgrid_filepath1, "wb") as fh:
             fh.write(pickle.dumps(TEST_VGRID_USAGE1))
 
-        quot_test_vgrid_filepath2 = os.path.join(quota_vgrid_path,
-                                                 "%s.pck" % TEST_VGRID_NAME2)
-        with open(quot_test_vgrid_filepath2, 'wb') as fh:
+        quot_test_vgrid_filepath2 = os.path.join(
+            quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME2
+        )
+        with open(quot_test_vgrid_filepath2, "wb") as fh:
             fh.write(pickle.dumps(TEST_VGRID_USAGE2))
 
-        quot_test_vgrid_filepath3 = os.path.join(quota_vgrid_path,
-                                                 "%s.pck" % TEST_VGRID_NAME3)
-        with open(quot_test_vgrid_filepath3, 'wb') as fh:
+        quot_test_vgrid_filepath3 = os.path.join(
+            quota_vgrid_path, "%s.pck" % TEST_VGRID_NAME3
+        )
+        with open(quot_test_vgrid_filepath3, "wb") as fh:
             fh.write(pickle.dumps(TEST_VGRID_USAGE3))
 
         test_client_peers_filepath = os.path.join(
-            test_client_peers_path, peers_filename)
-        with open(test_client_peers_filepath, 'wb') as fh:
+            test_client_peers_path, peers_filename
+        )
+        with open(test_client_peers_filepath, "wb") as fh:
             fh.write(pickle.dumps(TEST_PEERS))
 
-        quota_test_client_ext_path \
-            = os.path.join(quota_user_path,
-                           "%s.pck" % client_id_dir(TEST_EXT_DN))
-        with open(quota_test_client_ext_path, 'wb') as fh:
+        quota_test_client_ext_path = os.path.join(
+            quota_user_path, "%s.pck" % client_id_dir(TEST_EXT_DN)
+        )
+        with open(quota_test_client_ext_path, "wb") as fh:
             fh.write(pickle.dumps(TEST_EXT_USAGE))
 
-        quota_test_freeze_path = os.path.join(quota_freeze_path,
-                                              "%s.pck"
-                                              % client_id_dir(TEST_CLIENT_DN))
-        with open(quota_test_freeze_path, 'wb') as fh:
+        quota_test_freeze_path = os.path.join(
+            quota_freeze_path, "%s.pck" % client_id_dir(TEST_CLIENT_DN)
+        )
+        with open(quota_test_freeze_path, "wb") as fh:
             fh.write(pickle.dumps(TEST_FREEZE_USAGE))
 
     def test_accounting(self):
@@ -211,36 +243,43 @@ def test_accounting(self):
         usage = get_usage(self.configuration)
         self.assertNotEqual(usage, {})
-        accounting = usage.get('accounting', {})
+        accounting = usage.get("accounting", {})
         test_user_accounting = accounting.get(TEST_CLIENT_DN, {})
         self.assertNotEqual(test_user_accounting, {})
-        home_total = test_user_accounting.get('home_total', 0)
+        home_total = test_user_accounting.get("home_total", 0)
         self.assertEqual(home_total, TEST_CLIENT_BYTES)
-        vgrid_total = test_user_accounting.get('vgrid_total', 0)
+        vgrid_total = test_user_accounting.get("vgrid_total", 0)
         self.assertEqual(vgrid_total, TEST_VGRID_TOTAL_BYTES)
-        ext_users_total = test_user_accounting.get('ext_users_total', 0)
+        ext_users_total = test_user_accounting.get("ext_users_total", 0)
         self.assertEqual(ext_users_total, TEST_EXT_BYTES)
-        freeze_total = test_user_accounting.get('freeze_total', 0)
+        freeze_total = test_user_accounting.get("freeze_total", 0)
         self.assertEqual(freeze_total, TEST_FREEZE_BYTES)
-        total_bytes = test_user_accounting.get('total_bytes', 0)
+        total_bytes = test_user_accounting.get("total_bytes", 0)
         self.assertEqual(total_bytes, TEST_TOTAL_BYTES)
 
     def test_human_readable_filesize_valid(self):
         """Test human-friendly format helper success on valid byte sizes"""
-        valid = [(0, "0 B"), (42, "42.000 B"), (2**10, "1.000 KiB"),
-                 (2**30, "1.000 GiB"), (2**50, "1.000 PiB"),
-                 (2**89, "512.000 YiB"), (2**90 - 2**70, "1023.999 YiB")]
-        for (size, expect) in valid:
+        valid = [
+            (0, "0 B"),
+            (42, "42.000 B"),
+            (2**10, "1.000 KiB"),
+            (2**30, "1.000 GiB"),
+            (2**50, "1.000 PiB"),
+            (2**89, "512.000 YiB"),
+            (2**90 - 2**70, "1023.999 YiB"),
+        ]
+        for size, expect in valid:
             self.assertEqual(human_readable_filesize(size), expect)
 
     def test_human_readable_filesize_invalid(self):
         """Test human-friendly format helper failure on invalid byte sizes"""
-        invalid = [(i, "NaN") for i in [False, None, "", "one", -1, 1.2, 2**90,
-                                        2**128]]
-        for (size, expect) in invalid:
+        invalid = [
+            (i, "NaN") for i in [False, None, "", "one", -1, 1.2, 2**90, 2**128]
+        ]
+        for size, expect in invalid:
             self.assertEqual(human_readable_filesize(size), expect)
diff --git a/tests/test_mig_lib_daemon.py b/tests/test_mig_lib_daemon.py
index bd8cd901c..0cd55ea4e 100644
--- a/tests/test_mig_lib_daemon.py
+++ b/tests/test_mig_lib_daemon.py
@@ -31,10 +31,22 @@
 import signal
 import time
 
-from mig.lib.daemon import _run_event, _stop_event, check_run, check_stop, \
-    do_run, interruptible_sleep, register_run_handler, register_stop_handler, \
-    reset_run, reset_stop, run_handler, stop_handler, stop_running, \
-    unregister_signal_handlers
+from mig.lib.daemon import (
+    _run_event,
+    _stop_event,
+    check_run,
+    check_stop,
+    do_run,
+    interruptible_sleep,
+    register_run_handler,
+    register_stop_handler,
+    reset_run,
+    reset_stop,
+    run_handler,
+    stop_handler,
+    stop_running,
+    unregister_signal_handlers,
+)
 from tests.support import FakeConfiguration, FakeLogger, MigTestCase
 
 
@@ -42,19 +54,25 @@ class MigLibDaemon(MigTestCase):
     """Unit tests for daemon related helper functions"""
 
     # Signals registered across the tests and explicitly unregistered on init
-    _used_signals = [signal.SIGCONT, signal.SIGINT, signal.SIGALRM,
-                     signal.SIGABRT, signal.SIGUSR1, signal.SIGUSR2]
+    _used_signals = [
+        signal.SIGCONT,
+        signal.SIGINT,
+        signal.SIGALRM,
+        signal.SIGABRT,
+        signal.SIGUSR1,
+        signal.SIGUSR2,
+    ]
 
     def _provide_configuration(self):
         """Set up isolated test configuration and logger for the tests"""
-        return 'testconfig'
+        return "testconfig"
 
     def before_each(self):
         """Set up any test configuration and reset state before each test"""
         # Create dummy sig and frame values for isolated test use
-        self.sig = 'SIGNAL'
-        self.frame = 'FRAME'
+        self.sig = "SIGNAL"
+        self.frame = "FRAME"
 
         # Reset event states
         reset_run()
@@ -94,7 +112,7 @@ def test_interruptible_sleep(self):
         max_secs = 4.2
         start = time.time()
         signal.alarm(1)
-        interruptible_sleep(self.configuration, max_secs, (check_run, ))
+        interruptible_sleep(self.configuration, max_secs, (check_run,))
         self.assertTrue(check_run())
         end = time.time()
         self.assertTrue(end - start < max_secs)
@@ -303,14 +321,17 @@ def test_concurrent_event_handling(self):
 
     def test_interruptible_sleep_immediate_break(self):
         """Test interruptible_sleep with immediate break condition"""
+
         def immediate_true():
             return True
 
         start = time.time()
         interruptible_sleep(self.configuration, 5.0, [immediate_true])
         duration = time.time() - start
-        self.assertTrue(duration < 0.1,
-                        "Sleep should exit immediately but took %s" % duration)
+        self.assertTrue(
+            duration < 0.1,
+            "Sleep should exit immediately but took %s" % duration,
+        )
 
     def test_reset_event_helpers(self):
         """Test simple event reset helpers"""
@@ -337,36 +358,41 @@ def test_unregister_signal_handlers_explicit(self):
         register_stop_handler(self.configuration, signal.SIGABRT)
 
         # Verify handlers were set
-        self.assertEqual(signal.getsignal(signal.SIGALRM).__name__,
-                         'run_handler')
-        self.assertEqual(signal.getsignal(signal.SIGABRT).__name__,
-                         'stop_handler')
+        self.assertEqual(
+            signal.getsignal(signal.SIGALRM).__name__, "run_handler"
+        )
+        self.assertEqual(
+            signal.getsignal(signal.SIGABRT).__name__, "stop_handler"
+        )
 
         # Unregister specific signals
-        unregister_signal_handlers(self.configuration, [signal.SIGALRM,
-                                                        signal.SIGABRT])
+        unregister_signal_handlers(
+            self.configuration, [signal.SIGALRM, signal.SIGABRT]
+        )
         self.assertEqual(signal.getsignal(signal.SIGALRM), signal.SIG_IGN)
         self.assertEqual(signal.getsignal(signal.SIGABRT), signal.SIG_IGN)
 
     def test_interruptible_sleep_condition_after_interval(self):
         """Test interruptible_sleep break condition after one interval"""
-        state = {'count': 0}
+        state = {"count": 0}
 
         def counter_condition():
-            state['count'] += 1
-            return state['count'] >= 2
+            state["count"] += 1
+            return state["count"] >= 2
 
         start = time.time()
-        interruptible_sleep(self.configuration, 5.0, [counter_condition],
-                            nap_secs=0.1)
+        interruptible_sleep(
+            self.configuration, 5.0, [counter_condition], nap_secs=0.1
+        )
         duration = time.time() - start
         self.assertAlmostEqual(duration, 0.2, delta=0.15)
 
     def test_interruptible_sleep_maxsecs_equals_napsecs(self):
         """Test interruptible_sleep with max_secs exactly matching nap_secs"""
         start = time.time()
-        interruptible_sleep(self.configuration, 0.1, [lambda: False],
-                            nap_secs=0.1)
+        interruptible_sleep(
+            self.configuration, 0.1, [lambda: False], nap_secs=0.1
+        )
         duration = time.time() - start
         self.assertAlmostEqual(duration, 0.1, delta=0.05)
 
@@ -379,8 +405,9 @@ def faulty_condition():
             self.logger.error(SLEEP_ERR)
 
         start = time.time()
-        interruptible_sleep(self.configuration, 0.1, [faulty_condition],
-                            nap_secs=0.01)
+        interruptible_sleep(
+            self.configuration, 0.1, [faulty_condition], nap_secs=0.01
+        )
         duration = time.time() - start
         self.assertAlmostEqual(duration, 0.1, delta=0.05)
         try:
@@ -402,22 +429,26 @@ def test_unregister_default_signals(self):
         self.assertEqual(signal.getsignal(signal.SIGCONT), signal.SIG_IGN)
         self.assertEqual(signal.getsignal(signal.SIGUSR2), signal.SIG_IGN)
         # Verify custom signals remain after default unregister
-        self.assertEqual(signal.getsignal(signal.SIGINT).__name__,
-                         'run_handler')
-        self.assertEqual(signal.getsignal(signal.SIGALRM).__name__,
-                         'stop_handler')
+        self.assertEqual(
+            signal.getsignal(signal.SIGINT).__name__, "run_handler"
+        )
+        self.assertEqual(
+            signal.getsignal(signal.SIGALRM).__name__, "stop_handler"
+        )
 
     def test_register_default_signal(self):
         """Test handler registration with default signal values"""
         # Run handler should default to SIGCONT
         register_run_handler(self.configuration)
-        self.assertEqual(signal.getsignal(signal.SIGCONT).__name__,
-                         'run_handler')
+        self.assertEqual(
+            signal.getsignal(signal.SIGCONT).__name__, "run_handler"
+        )
         # Stop handler should default to SIGINT
         register_stop_handler(self.configuration)
-        self.assertEqual(signal.getsignal(signal.SIGINT).__name__,
-                         'stop_handler')
+        self.assertEqual(
+            signal.getsignal(signal.SIGINT).__name__, "stop_handler"
+        )
 
     def test_reset_unregistered_signals(self):
         """Test unregister responds gracefully to previously unregistered signals"""
@@ -440,21 +471,22 @@ def test_interruptible_sleep_break_not_callable(self):
 
     def test_interruptible_sleep_all_conditions_checked(self):
         """Verify all break conditions are checked each sleep interval"""
-        counter = {'count': 0}
+        counter = {"count": 0}
         max_checks = 3
 
         def counter_condition():
-            if counter['count'] < max_checks:
-                counter['count'] += 1
-            return counter['count'] >= max_checks
+            if counter["count"] < max_checks:
+                counter["count"] += 1
+            return counter["count"] >= max_checks
 
         start = time.time()
-        interruptible_sleep(self.configuration, 5.0, [counter_condition],
-                            nap_secs=0.1)
+        interruptible_sleep(
+            self.configuration, 5.0, [counter_condition], nap_secs=0.1
+        )
         duration = time.time() - start
         # Should run for ~0.3 sec (3 naps of 0.1 sec)
         self.assertAlmostEqual(duration, 0.3, delta=0.15)
-        self.assertEqual(counter['count'], max_checks)
+        self.assertEqual(counter["count"], max_checks)
 
     def test_interruptible_sleep_naps_remaining(self):
         """Test interruptible_sleep counts down remaining naps correctly"""
@@ -517,21 +549,20 @@ def test_event_state_persistence(self):
 
     def test_signal_handler_dispatch(self):
         """Verify signal handlers dispatch correct signals"""
-        test_signals = {
-            'run': [signal.SIGUSR1],
-            'stop': [signal.SIGUSR2]
-        }
+        test_signals = {"run": [signal.SIGUSR1], "stop": [signal.SIGUSR2]}
 
-        for func, sigs in [(register_run_handler, test_signals['run']),
-                           (register_stop_handler, test_signals['stop'])]:
+        for func, sigs in [
+            (register_run_handler, test_signals["run"]),
+            (register_stop_handler, test_signals["stop"]),
+        ]:
             for sig in sigs:
                 func(self.configuration, sig)
                 # Verify handler registration
                 dispatch = signal.getsignal(sig)
                 if func == register_run_handler:
-                    self.assertEqual(dispatch.__name__, 'run_handler')
+                    self.assertEqual(dispatch.__name__, "run_handler")
                 else:
-                    self.assertEqual(dispatch.__name__, 'stop_handler')
+                    self.assertEqual(dispatch.__name__, "stop_handler")
 
     def test_event_set_unset_lifecycle(self):
         """Verify full event lifecycle"""
diff --git a/tests/test_mig_lib_events.py b/tests/test_mig_lib_events.py
index 368b344d5..b339bdff1 100644
--- a/tests/test_mig_lib_events.py
+++ b/tests/test_mig_lib_events.py
@@ -32,11 +32,26 @@
 import unittest
 
 # Imports of the code under test
-from mig.lib.events import _restore_env, _save_env, at_remain, cron_match, \
-    get_path_expand_map, get_time_expand_map, load_atjobs, load_crontab
-from mig.lib.events import parse_and_save_atjobs, parse_and_save_crontab, \
-    parse_atjobs, parse_atjobs_contents, parse_crontab, \
-    parse_crontab_contents, run_cron_command, run_events_command, legacy_main
+from mig.lib.events import (
+    _restore_env,
+    _save_env,
+    at_remain,
+    cron_match,
+    get_path_expand_map,
+    get_time_expand_map,
+    legacy_main,
+    load_atjobs,
+    load_crontab,
+    parse_and_save_atjobs,
+    parse_and_save_crontab,
+    parse_atjobs,
+    parse_atjobs_contents,
+    parse_crontab,
+    parse_crontab_contents,
+    run_cron_command,
+    run_events_command,
+)
+
 # Imports required for the unit tests themselves
 from tests.support import MigTestCase, ensure_dirs_exist
 
@@ -355,8 +370,7 @@ def test_cron_match_with_wildcards(self):
             ),
         ]
         for job, expected in test_cases:
-            self.assertEqual(cron_match(
-                self.configuration, now, job), expected)
+            self.assertEqual(cron_match(self.configuration, now, job), expected)
 
     def test_cron_match_specific_time(self):
         """Test cron_match rejects non-matching time"""
@@ -507,7 +521,8 @@ def test_cron_match_with_leading_zero_match(self):
                 },
                 # Get first Monday of current month
                 now.replace(day=7).replace(
-                    day=now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=now.replace(day=7).day - now.replace(day=7).weekday()
+                ),
             ),
             (
                 {
@@ -519,7 +534,10 @@ def test_cron_match_with_leading_zero_match(self):
                 },
                 # Get first Friday of current month
                 now.replace(day=7).replace(
-                    day=4 + now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=4
+                    + now.replace(day=7).day
+                    - now.replace(day=7).weekday()
+                ),
            ),
            (
                {
@@ -531,7 +549,8 @@ def test_cron_match_with_leading_zero_match(self):
                 },
                 # Get first Monday of current month
                 now.replace(day=7).replace(
-                    day=now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=now.replace(day=7).day - now.replace(day=7).weekday()
+                ),
             ),
             (
                 {
@@ -543,7 +562,10 @@ def test_cron_match_with_leading_zero_match(self):
                 },
                 # Get first Friday of current month
                 now.replace(day=7).replace(
-                    day=4 + now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=4
+                    + now.replace(day=7).day
+                    - now.replace(day=7).weekday()
+                ),
             ),
         ]
         for job, now in test_cases:
@@ -643,7 +665,10 @@ def test_cron_match_with_leading_zero_mismatch(self):
                 },
                 # Get first Friday of current month
                 now.replace(day=7).replace(
-                    day=4 + now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=4
+                    + now.replace(day=7).day
+                    - now.replace(day=7).weekday()
+                ),
             ),
             (
                 {
@@ -655,7 +680,8 @@ def test_cron_match_with_leading_zero_mismatch(self):
                 },
                 # Get first Monday of current month
                 now.replace(day=7).replace(
-                    day=now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=now.replace(day=7).day - now.replace(day=7).weekday()
+                ),
             ),
             (
                 {
@@ -667,7 +693,10 @@ def test_cron_match_with_leading_zero_mismatch(self):
                 },
                 # Get first Friday of current month
                 now.replace(day=7).replace(
-                    day=4 + now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=4
+                    + now.replace(day=7).day
+                    - now.replace(day=7).weekday()
+                ),
             ),
             (
                 {
@@ -679,7 +708,8 @@ def test_cron_match_with_leading_zero_mismatch(self):
                 },
                 # Get first Monday of current month
                 now.replace(day=7).replace(
-                    day=now.replace(day=7).day - now.replace(day=7).weekday()),
+                    day=now.replace(day=7).day - now.replace(day=7).weekday()
+                ),
             ),
         ]
         for job, now in test_cases:
@@ -1409,8 +1439,7 @@ def test_get_path_expand_map_with_relative_path(self):
         trigger_path = "../relative/path/file.txt"
         rule = {"vgrid_name": "test", "run_as": DUMMY_USER_DN}
         expanded = get_path_expand_map(trigger_path, rule, "modified")
-        self.assertEqual(expanded["+TRIGGERPATH+"],
-                         "../relative/path/file.txt")
+        self.assertEqual(expanded["+TRIGGERPATH+"], "../relative/path/file.txt")
         self.assertEqual(expanded["+TRIGGERFILENAME+"], "file.txt")
         self.assertEqual(expanded["+TRIGGERPREFIX+"], "file")
         self.assertEqual(expanded["+TRIGGEREXTENSION+"], ".txt")
@@ -1807,8 +1836,7 @@ def test_parse_and_save_crontab(self):
     def test_parse_atjobs(self):
         """Test parsing atjobs content lines"""
         parsed = parse_atjobs_contents(
-            self.configuration, DUMMY_USER_DN,
-            DUMMY_ATJOBS_CONTENT.splitlines()
+            self.configuration, DUMMY_USER_DN, DUMMY_ATJOBS_CONTENT.splitlines()
         )
         self.assertEqual(len(parsed), 1)
         self.assertEqual(parsed[0]["command"], ["/bin/future_command"])
@@ -1816,8 +1844,7 @@ def test_parse_atjobs_contents(self):
     def test_parse_atjobs_contents(self):
         """Test parsing atjobs content lines"""
         parsed = parse_atjobs_contents(
-            self.configuration, DUMMY_USER_DN,
-            DUMMY_ATJOBS_CONTENT.splitlines()
+            self.configuration, DUMMY_USER_DN, DUMMY_ATJOBS_CONTENT.splitlines()
         )
         self.assertEqual(len(parsed), 1)
         self.assertEqual(parsed[0]["command"], ["/bin/future_command"])
@@ -1825,8 +1852,9 @@ def test_parse_crontab(self):
     def test_parse_crontab(self):
         """Test parsing crontab content lines"""
         parsed = parse_crontab_contents(
-            self.configuration, DUMMY_USER_DN,
-            DUMMY_CRONTAB_CONTENT.splitlines()
+            self.configuration,
+            DUMMY_USER_DN,
+            DUMMY_CRONTAB_CONTENT.splitlines(),
         )
         self.assertEqual(len(parsed), 2)
         self.assertEqual(parsed[0]["command"], ["/bin/test_command"])
@@ -1834,8 +1862,9 @@ def test_parse_crontab_contents(self):
     def test_parse_crontab_contents(self):
         """Test parsing crontab content lines"""
         parsed = parse_crontab_contents(
-            self.configuration, DUMMY_USER_DN,
-            DUMMY_CRONTAB_CONTENT.splitlines()
+            self.configuration,
+            DUMMY_USER_DN,
+            DUMMY_CRONTAB_CONTENT.splitlines(),
        )
         self.assertEqual(len(parsed), 2)
         self.assertEqual(parsed[0]["command"], ["/bin/test_command"])
@@ -3387,6 +3416,7 @@ def before_each(self):
 
     def test_existing_main(self):
         """Wrap existing self-tests"""
+
         def raise_on_error_exit(exit_code):
             if exit_code != 0:
                 if raise_on_error_exit.last_print is not None:
@@ -3394,14 +3424,19 @@ def raise_on_error_exit(exit_code):
                 else:
                     identifying_message = "unknown"
                 raise AssertionError(
-                    'legacy test failure: %s' % (identifying_message,))
+                    "legacy test failure: %s" % (identifying_message,)
+                )
 
         raise_on_error_exit.last_print = None
 
         def record_last_print(value):
             raise_on_error_exit.last_print = value
 
-        legacy_main(self.configuration, print=record_last_print, _exit=raise_on_error_exit)
+        legacy_main(
+            self.configuration,
+            print=record_last_print,
+            _exit=raise_on_error_exit,
+        )
 
 
 if __name__ == "__main__":
diff --git a/tests/test_mig_lib_janitor.py b/tests/test_mig_lib_janitor.py
index 59a522cde..219fe4db0 100644
--- a/tests/test_mig_lib_janitor.py
+++ b/tests/test_mig_lib_janitor.py
@@ -32,18 +32,36 @@
 import time
 import unittest
 
-from mig.lib.janitor import EXPIRE_DUMMY_JOBS_DAYS, EXPIRE_REQ_DAYS, \
-    EXPIRE_STATE_DAYS, EXPIRE_TWOFACTOR_DAYS, MANAGE_TRIVIAL_REQ_MINUTES, \
-    REMIND_REQ_DAYS, SECS_PER_DAY, SECS_PER_HOUR, SECS_PER_MINUTE, \
-    _clean_stale_state_files, _lookup_last_run, _update_last_run, \
-    clean_mig_system_files, clean_no_job_helpers, \
-    clean_sessid_to_mrls_link_home, clean_twofactor_sessions, \
-    clean_webserver_home, handle_cache_updates, handle_janitor_tasks, \
-    handle_pending_requests, handle_session_cleanup, handle_state_cleanup, \
-    manage_single_req, manage_trivial_user_requests, \
-    remind_and_expire_user_pending, task_triggers
+from mig.lib.janitor import (
+    EXPIRE_DUMMY_JOBS_DAYS,
+    EXPIRE_REQ_DAYS,
+    EXPIRE_STATE_DAYS,
+    EXPIRE_TWOFACTOR_DAYS,
+    MANAGE_TRIVIAL_REQ_MINUTES,
+    REMIND_REQ_DAYS,
+    SECS_PER_DAY,
+    SECS_PER_HOUR,
+    SECS_PER_MINUTE,
+    _clean_stale_state_files,
+    _lookup_last_run,
+    _update_last_run,
+    clean_mig_system_files,
+    clean_no_job_helpers,
+    clean_sessid_to_mrls_link_home,
+    clean_twofactor_sessions,
+    clean_webserver_home,
+    handle_cache_updates,
+    handle_janitor_tasks,
+    handle_pending_requests,
+    handle_session_cleanup,
+    handle_state_cleanup,
+    manage_single_req,
+    manage_trivial_user_requests,
+    remind_and_expire_user_pending,
+    task_triggers,
+)
 from mig.shared.accountreq import save_account_request
-from mig.shared.base import distinguished_name_to_user, client_id_dir
+from mig.shared.base import client_id_dir, distinguished_name_to_user
 from mig.shared.pwcrypto import generate_reset_token
 from tests.support import MigTestCase, ensure_dirs_exist
 
@@ -52,25 +70,31 @@
 TEST_USER_ORG = "Test Org"
 TEST_USER_EMAIL = "test@example.com"
 # TODO: move next to support.usersupp?
-TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=%s/OU=NA/CN=%s/emailAddress=%s' % \
-    (TEST_USER_ORG, TEST_USER_FULLNAME, TEST_USER_EMAIL)
-TEST_SKIP_EMAIL = ''
+TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=%s/OU=NA/CN=%s/emailAddress=%s" % (
+    TEST_USER_ORG,
+    TEST_USER_FULLNAME,
+    TEST_USER_EMAIL,
+)
+TEST_SKIP_EMAIL = ""
 # TODO: adjust password reset token helpers to handle configured services
 # it currently silently fails if not in migoid(c) or migcert
 # TEST_SERVICE = 'dummy-svc'
-TEST_AUTH = TEST_SERVICE = 'migoid'
-TEST_USERDB = 'MiG-users.db'
-TEST_PEER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com'
+TEST_AUTH = TEST_SERVICE = "migoid"
+TEST_USERDB = "MiG-users.db"
+TEST_PEER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com"
 # NOTE: these passwords are not and should not ever be used outside unit tests
-TEST_MODERN_PW = 'NoSuchPassword_42'
-TEST_MODERN_PW_PBKDF2 = \
+TEST_MODERN_PW = "NoSuchPassword_42"
+TEST_MODERN_PW_PBKDF2 = (
     "PBKDF2$sha256$10000$XMZGaar/pU4PvWDr$w0dYjezF6JGtSiYPexyZMt3lM2134uix"
-TEST_NEW_MODERN_PW_PBKDF2 = \
+)
+TEST_NEW_MODERN_PW_PBKDF2 = (
     "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf10n581pjXFHrn"
-TEST_INVALID_PW_PBKDF2 = \
+)
+TEST_INVALID_PW_PBKDF2 = (
     "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf1rn1pjX0n58FH"
+)
 # NOTE: tokens always should contain a multiple of 4 chars
-INVALID_TEST_TOKEN = 'THIS_RESET_TOKEN_WAS_NEVER_VALID'
+INVALID_TEST_TOKEN = "THIS_RESET_TOKEN_WAS_NEVER_VALID"
 
 
 class MigLibJanitor(MigTestCase):
@@ -78,11 +102,11 @@ class MigLibJanitor(MigTestCase):
 
     def _provide_configuration(self):
         """Prepare isolated test config"""
-        return 'testconfig'
+        return "testconfig"
 
-    def _prepare_test_file(self, path, times=None, content='test'):
+    def _prepare_test_file(self, path, times=None, content="test"):
         """Prepare file in path with optional times for timestamp"""
-        with open(path, 'w') as fp:
+        with open(path, "w") as fp:
             fp.write(content)
         os.utime(path, times)
 
@@ -94,8 +118,9 @@ def before_each(self):
         self.configuration.site_login_methods.append(TEST_AUTH)
         # Prevent admin email during reject, etc.
         self.configuration.admin_email = TEST_SKIP_EMAIL
-        self.user_db_path = os.path.join(self.configuration.user_db_home,
-                                         TEST_USERDB)
+        self.user_db_path = os.path.join(
+            self.configuration.user_db_home, TEST_USERDB
+        )
         # Create fake fs layout matching real systems
         ensure_dirs_exist(self.configuration.user_pending)
         ensure_dirs_exist(self.configuration.user_db_home)
@@ -110,8 +135,9 @@ def before_each(self):
         ensure_dirs_exist(self.configuration.sessid_to_mrsl_link_home)
         ensure_dirs_exist(self.configuration.mrsl_files_dir)
         ensure_dirs_exist(self.configuration.resource_pending)
-        dummy_job = os.path.join(self.configuration.user_home,
-                                 "no_grid_jobs_in_grid_scheduler")
+        dummy_job = os.path.join(
+            self.configuration.user_home, "no_grid_jobs_in_grid_scheduler"
+        )
         ensure_dirs_exist(dummy_job)
 
         # Prepare user DB with a single dummy user for all tests
@@ -124,20 +150,20 @@ def before_each(self):
     def test_last_run_bookkeeping(self):
         """Register a last run timestamp and check it"""
         expect = -1
-        stamp = _lookup_last_run(self.configuration, 'janitor_task')
+        stamp = _lookup_last_run(self.configuration, "janitor_task")
         self.assertEqual(stamp, expect)
         expect = 42
-        stamp = _update_last_run(self.configuration, 'janitor_task', expect)
+        stamp = _update_last_run(self.configuration, "janitor_task", expect)
         self.assertEqual(stamp, expect)
         expect = time.time()
-        stamp = _update_last_run(self.configuration, 'janitor_task', expect)
+        stamp = _update_last_run(self.configuration, "janitor_task", expect)
         self.assertEqual(stamp, expect)
 
     def test_clean_mig_system_files(self):
         """Test clean_mig system files helper"""
         test_time = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1
-        valid_filenames = ['fresh.log', 'current.tmp']
-        stale_filenames = ['tmp_expired.txt', 'no_grid_jobs.123']
+        valid_filenames = ["fresh.log", "current.tmp"]
+        stale_filenames = ["tmp_expired.txt", "no_grid_jobs.123"]
         for name in valid_filenames + stale_filenames:
             path = os.path.join(self.configuration.mig_system_files, name)
             self._prepare_test_file(path, (test_time, test_time))
@@ -145,8 +171,10 @@ def test_clean_mig_system_files(self):
 
         handled = clean_mig_system_files(self.configuration)
         self.assertEqual(handled, len(stale_filenames))
-        self.assertEqual(len(os.listdir(self.configuration.mig_system_files)),
-                         len(valid_filenames))
+        self.assertEqual(
+            len(os.listdir(self.configuration.mig_system_files)),
+            len(valid_filenames),
+        )
         for name in valid_filenames:
             path = os.path.join(self.configuration.mig_system_files, name)
             self.assertTrue(os.path.exists(path))
@@ -158,8 +186,8 @@ def test_clean_webserver_home(self):
         """Test clean webserver files helper"""
         stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1
         test_dir = self.configuration.webserver_home
-        valid_filename = 'fresh.log'
-        stale_filename = 'stale.log'
+        valid_filename = "fresh.log"
+        stale_filename = "stale.log"
         valid_path = os.path.join(test_dir, valid_filename)
         stale_path = os.path.join(test_dir, stale_filename)
         self._prepare_test_file(valid_path)
@@ -175,10 +203,11 @@ def test_clean_no_job_helpers(self):
     def test_clean_no_job_helpers(self):
         """Test clean dummy job helper files"""
         stale_stamp = time.time() - EXPIRE_DUMMY_JOBS_DAYS * SECS_PER_DAY - 1
-        test_dir = os.path.join(self.configuration.user_home,
-                                "no_grid_jobs_in_grid_scheduler")
-        valid_filename = 'alive.txt'
-        stale_filename = 'expired.txt'
+        test_dir = os.path.join(
+            self.configuration.user_home, "no_grid_jobs_in_grid_scheduler"
+        )
+        valid_filename = "alive.txt"
+        stale_filename = "expired.txt"
         valid_path = os.path.join(test_dir, valid_filename)
         stale_path = os.path.join(test_dir, stale_filename)
         self._prepare_test_file(valid_path)
@@ -195,8 +224,8 @@ def test_clean_twofactor_sessions(self):
         """Test clean twofactor sessions"""
         stale_stamp = time.time() - EXPIRE_TWOFACTOR_DAYS * SECS_PER_DAY - 1
         test_dir = self.configuration.twofactor_home
-        valid_filename = 'current'
-        stale_filename = 'expired'
+        valid_filename = "current"
+        stale_filename = "expired"
         valid_path = os.path.join(test_dir, valid_filename)
         stale_path = os.path.join(test_dir, stale_filename)
         self._prepare_test_file(valid_path)
@@ -213,8 +242,8 @@ def test_clean_sessid_to_mrls_link_home(self):
         """Test clean session MRSL link files"""
         stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1
         test_dir = self.configuration.sessid_to_mrsl_link_home
-        valid_filename = 'active_session_link'
-        stale_filename = 'expired_session_link'
+        valid_filename = "active_session_link"
+        stale_filename = "expired_session_link"
         valid_path = os.path.join(test_dir, valid_filename)
         stale_path = os.path.join(test_dir, stale_filename)
         self._prepare_test_file(valid_path)
@@ -232,12 +261,14 @@ def test_handle_state_cleanup(self):
         # Create a stale file in each location to clean up
         stale_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1
         mig_path = os.path.join(
-            self.configuration.mig_system_files, 'tmpAbCd1234')
-        web_path = os.path.join(self.configuration.webserver_home, 'stale.txt')
+            self.configuration.mig_system_files, "tmpAbCd1234"
+        )
+        web_path = os.path.join(self.configuration.webserver_home, "stale.txt")
         empty_job_path = os.path.join(
-            os.path.join(self.configuration.user_home,
-                         "no_grid_jobs_in_grid_scheduler"),
-            'sleep.job'
+            os.path.join(
+                self.configuration.user_home, "no_grid_jobs_in_grid_scheduler"
+            ),
+            "sleep.job",
         )
         stale_paths = [mig_path, web_path, empty_job_path]
         for path in stale_paths:
@@ -252,12 +283,17 @@ def test_handle_session_cleanup(self):
     def test_handle_session_cleanup(self):
         """Test combined session cleanup"""
-        stale_stamp = time.time() - max(EXPIRE_STATE_DAYS,
-                                        EXPIRE_TWOFACTOR_DAYS) * SECS_PER_DAY - 1
+        stale_stamp = (
+            time.time()
+            - max(EXPIRE_STATE_DAYS, EXPIRE_TWOFACTOR_DAYS) * SECS_PER_DAY
+            - 1
+        )
         session_path = os.path.join(
-            self.configuration.sessid_to_mrsl_link_home, 'expired.txt')
+            self.configuration.sessid_to_mrsl_link_home, "expired.txt"
+        )
         twofactor_path = os.path.join(
-            self.configuration.twofactor_home, 'expired.txt')
+            self.configuration.twofactor_home, "expired.txt"
+        )
         test_paths = [session_path, twofactor_path]
         for path in test_paths:
             os.makedirs(os.path.dirname(path), exist_ok=True)
@@ -271,17 +307,17 @@ def test_manage_pending_user_request(self):
 
     def test_manage_pending_user_request(self):
         """Test pending user request management"""
-        req_id = 'req_id'
+        req_id = "req_id"
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
-            'password': TEST_MODERN_PW,
-            'peers': [TEST_PEER_DN],
-            'email': TEST_USER_EMAIL,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
+            "password": TEST_MODERN_PW,
+            "peers": [TEST_PEER_DN],
+            "email": TEST_USER_EMAIL,
         }
         self.assertDirEmpty(self.configuration.user_pending)
@@ -293,24 +329,25 @@ def test_manage_pending_user_request(self):
         os.utime(req_path, (req_age, req_age))
 
         # Need user DB and path to simulate existing user
-        user_dir = os.path.join(self.configuration.user_home,
-                                client_id_dir(TEST_USER_DN))
+        user_dir = os.path.join(
+            self.configuration.user_home, client_id_dir(TEST_USER_DN)
+        )
         os.makedirs(user_dir, exist_ok=True)
         handled = manage_trivial_user_requests(self.configuration)
         self.assertEqual(handled, 1)
 
     def test_expire_user_pending(self):
         """Test pending user request expiration reminders"""
-        req_id = 'expired_req'
+        req_id = "expired_req"
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
-            'password': TEST_MODERN_PW,
-            'peers': [TEST_PEER_DN],
-            'email': TEST_USER_EMAIL,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
+            "password": TEST_MODERN_PW,
+            "peers": [TEST_PEER_DN],
+            "email": TEST_USER_EMAIL,
         }
         self.assertDirEmpty(self.configuration.user_pending)
         saved, req_path = save_account_request(self.configuration, req_dict)
@@ -333,42 +370,46 @@ def test_handle_pending_requests(self):
         """Test combined request handling"""
         # Create requests (valid, expired)
         valid_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
-            'password': TEST_MODERN_PW,
-            'peers': [TEST_PEER_DN],
-            'email': TEST_USER_EMAIL,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
+            "password": TEST_MODERN_PW,
+            "peers": [TEST_PEER_DN],
+            "email": TEST_USER_EMAIL,
         }
         self.assertDirEmpty(self.configuration.user_pending)
-        saved, valid_req_path = save_account_request(self.configuration,
-                                                     valid_dict)
+        saved, valid_req_path = save_account_request(
+            self.configuration, valid_dict
+        )
         self.assertTrue(saved, "failed to save valid req")
         self.assertDirNotEmpty(self.configuration.user_pending)
         valid_id = os.path.basename(valid_req_path)
-        expired_id = 'expired_req'
+        expired_id = "expired_req"
         expired_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
-            'password': TEST_MODERN_PW,
-            'peers': [TEST_PEER_DN],
-            'email': TEST_USER_EMAIL,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
+            "password": TEST_MODERN_PW,
+            "peers": [TEST_PEER_DN],
+            "email": TEST_USER_EMAIL,
         }
         saved, expired_req_path = save_account_request(
-            self.configuration, expired_dict)
+            self.configuration, expired_dict
+        )
         self.assertTrue(saved, "failed to save expired req")
         expired_id = os.path.basename(expired_req_path)
         # Make just one old enough to expire
         expire_time = time.time() - EXPIRE_REQ_DAYS * SECS_PER_DAY - 1
-        os.utime(os.path.join(self.configuration.user_pending, expired_id),
-                 (expire_time, expire_time))
+        os.utime(
+            os.path.join(self.configuration.user_pending, expired_id),
+            (expire_time, expire_time),
+        )
 
         # NOTE: when using real user mail we currently hit send email errors.
         # We forgive those errors here and only check any known warnings.
@@ -383,25 +424,29 @@ def test_handle_janitor_tasks_full(self):
         """Test full janitor task scheduler"""
         # Prepare environment with pending tasks of each kind
         mig_stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1
-        mig_path = os.path.join(self.configuration.mig_system_files,
-                                'tmp-stale.txt')
-        two_path = os.path.join(self.configuration.twofactor_home, 'stale.txt')
+        mig_path = os.path.join(
+            self.configuration.mig_system_files, "tmp-stale.txt"
+        )
+        two_path = os.path.join(self.configuration.twofactor_home, "stale.txt")
         two_stamp = time.time() - EXPIRE_TWOFACTOR_DAYS * SECS_PER_DAY - 1
-        stale_tests = ((mig_path, mig_stamp), (two_path, two_stamp), )
-        for (stale_path, stale_stamp) in stale_tests:
+        stale_tests = (
+            (mig_path, mig_stamp),
+            (two_path, two_stamp),
+        )
+        for stale_path, stale_stamp in stale_tests:
             self._prepare_test_file(stale_path, (stale_stamp, stale_stamp))
             self.assertTrue(os.path.exists(stale_path))
-        req_id = 'expired_request'
+        req_id = "expired_request"
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
-            'password': TEST_MODERN_PW,
-            'peers': [TEST_PEER_DN],
-            'email': TEST_USER_EMAIL,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
+            "password": TEST_MODERN_PW,
+            "peers": [TEST_PEER_DN],
+            "email": TEST_USER_EMAIL,
         }
         self.assertDirEmpty(self.configuration.user_pending)
         saved, req_path = save_account_request(self.configuration, req_dict)
@@ -410,8 +455,10 @@ def test_handle_janitor_tasks_full(self):
         req_id = os.path.basename(req_path)
         # Make request very old
         req_age = time.time() - EXPIRE_REQ_DAYS * SECS_PER_DAY - 1
-        os.utime(os.path.join(self.configuration.user_pending, req_id),
-                 (req_age, req_age))
+        os.utime(
+            os.path.join(self.configuration.user_pending, req_id),
+            (req_age, req_age),
+        )
 
         # Set no last run timestamps to trigger all tasks
         now = time.time()
@@ -426,26 +473,26 @@ def test_handle_janitor_tasks_full(self):
         handled = handle_janitor_tasks(self.configuration, now=now)
         # self.assertEqual(handled, 3)  # state+session+requests
         self.assertEqual(handled, 5)  # state+session+3*request
-        for (stale_path, _) in stale_tests:
+        for stale_path, _ in stale_tests:
             self.assertFalse(os.path.exists(stale_path), stale_path)
 
     def test__clean_stale_state_files(self):
         """Test core stale state file cleaner helper"""
-        test_dir = self.temppath('stale_state_test', ensure_dir=True)
-        patterns = ['tmp_*', 'session_*']
+        test_dir = self.temppath("stale_state_test", ensure_dir=True)
+        patterns = ["tmp_*", "session_*"]
 
         # Create test files (fresh, expired, unexpired, non-matching)
         test_remove = [
-            ('tmp_expired.txt', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1),
-            ('session_old.dat', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1),
+            ("tmp_expired.txt", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1),
+            ("session_old.dat", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1),
         ]
         test_keep = [
-            ('tmp_fresh.txt', -1),
-            ('session_valid.dat', 0),
-            ('other_file.log', EXPIRE_STATE_DAYS * SECS_PER_DAY + 1),
+            ("tmp_fresh.txt", -1),
+            ("session_valid.dat", 0),
+            ("other_file.log", EXPIRE_STATE_DAYS * SECS_PER_DAY + 1),
         ]
-        for (name, age_diff) in test_keep + test_remove:
+        for name, age_diff in test_keep + test_remove:
             path = os.path.join(test_dir, name)
             stamp = time.time() - age_diff
             self._prepare_test_file(path, (stamp, stamp))
@@ -457,27 +504,27 @@ def test__clean_stale_state_files(self):
             patterns,
             EXPIRE_STATE_DAYS,
             time.time(),
-            include_dotfiles=False
+            include_dotfiles=False,
         )
         self.assertEqual(handled, 2)  # tmp_expired.txt + session_old.dat
-        for (name, _) in test_keep:
+        for name, _ in test_keep:
             path = os.path.join(test_dir, name)
             self.assertTrue(os.path.exists(path))
-        for (name, _) in test_remove:
+        for name, _ in test_remove:
             path = os.path.join(test_dir, name)
             self.assertFalse(os.path.exists(path))
 
     def test_manage_single_req_invalid(self):
         """Test request handling for invalid request"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'invalid': ['Missing required field: organization'],
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "invalid": ["Missing required field: organization"],
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
             # NOTE: we need original email here to match provisioned user
-            'email': TEST_USER_EMAIL,
+            "email": TEST_USER_EMAIL,
         }
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
@@ -486,17 +533,18 @@ def test_manage_single_req_invalid(self):
         # We forgive those errors here and only check any known warnings.
         # TODO: integrate generic skip email support and adjust here to fit
         self.logger.forgive_errors()
-        with self.assertLogs(level='INFO') as log_capture:
+        with self.assertLogs(level="INFO") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
             )
-        self.assertTrue(any('invalid account request' in msg
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any("invalid account request" in msg for msg in log_capture.output)
+        )
         # TODO: enable check for removed req once skip email allows it
         # self.assertFalse(os.path.exists(req_path),
         #                  "Failed to clean invalid req for %s" % req_path)
@@ -504,24 +552,24 @@ def test_manage_single_req_expired_token(self):
     def test_manage_single_req_expired_token(self):
         """Test request handling with expired reset token"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
             # NOTE: we need original email here to match provisioned user
-            'email': TEST_USER_EMAIL,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
-            'expire': time.time() + SECS_PER_DAY,
+            "email": TEST_USER_EMAIL,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
+            "expire": time.time() + SECS_PER_DAY,
         }
         # Mimic proper but old expired token
         timestamp = 42
         # IMPORTANT: we can't use a fixed token here due to dynamic crypto seed
-        req_dict['reset_token'] = generate_reset_token(self.configuration,
-                                                       req_dict, TEST_SERVICE,
-                                                       timestamp)
+        req_dict["reset_token"] = generate_reset_token(
+            self.configuration, req_dict, TEST_SERVICE, timestamp
+        )
         # Change password_hash here to mimic pw change
-        req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2
+        req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
 
@@ -529,17 +577,21 @@ def test_manage_single_req_expired_token(self):
         # We forgive those errors here and only check any known warnings.
         # TODO: integrate generic skip email support and adjust here to fit
         self.logger.forgive_errors()
-        with self.assertLogs(level='WARNING') as log_capture:
+        with self.assertLogs(level="WARNING") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
             )
-        self.assertTrue(any('reject expired reset token' in msg
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any(
+                "reject expired reset token" in msg
+                for msg in log_capture.output
+            )
+        )
         # TODO: enable check for removed req once skip email allows it
         # self.assertFalse(os.path.exists(req_path),
         #                  "Failed to clean token req for %s" % req_path)
@@ -548,52 +600,57 @@ def test_manage_single_req_invalid_token(self):
     def test_manage_single_req_invalid_token(self):
         """Test request handling with invalid reset token"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
             # NOTE: we need original email here to match provisioned user
-            'email': TEST_USER_EMAIL,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
-            'expire': time.time() - SECS_PER_DAY,
+            "email": TEST_USER_EMAIL,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
+            "expire": time.time() - SECS_PER_DAY,
         }
         # Inject known invalid reset token
-        req_dict['reset_token'] = INVALID_TEST_TOKEN
+        req_dict["reset_token"] = INVALID_TEST_TOKEN
         # Change password_hash here to mimic pw change
-        req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2
+        req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
 
-        with self.assertLogs(level='WARNING') as log_capture:
+        with self.assertLogs(level="WARNING") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
             )
-        self.assertTrue(any('reset with bad token' in msg
-                            for msg in log_capture.output))
-        self.assertFalse(os.path.exists(req_path),
-                         "Failed to clean token req for %s" % req_path)
+        self.assertTrue(
+            any("reset with bad token" in msg for msg in log_capture.output)
+        )
+        self.assertFalse(
+            os.path.exists(req_path),
+            "Failed to clean token req for %s" % req_path,
+        )
 
     def test_manage_single_req_collision(self):
         """Test request handling with existing user collision"""
         # Create collision with the already provisioned user with TEST_USER_DN
         changed_full_name = "Changed Test Name"
         req_dict = {
-            'client_id': TEST_USER_DN.replace(TEST_USER_FULLNAME,
-                                              changed_full_name),
-            'distinguished_name': TEST_USER_DN.replace(TEST_USER_FULLNAME,
-                                                       changed_full_name),
-            'auth': [TEST_AUTH],
-            'full_name': changed_full_name,
-            'organization': TEST_USER_ORG,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
+            "client_id": TEST_USER_DN.replace(
+                TEST_USER_FULLNAME, changed_full_name
+            ),
+            "distinguished_name": TEST_USER_DN.replace(
+                TEST_USER_FULLNAME, changed_full_name
+            ),
+            "auth": [TEST_AUTH],
+            "full_name": changed_full_name,
+            "organization": TEST_USER_ORG,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
             # NOTE: we need original email here to cause collision
-            'email': TEST_USER_EMAIL,
+            "email": TEST_USER_EMAIL,
         }
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
@@ -602,16 +659,17 @@ def test_manage_single_req_collision(self):
        # We forgive those errors here and only check any known warnings.
         # TODO: integrate generic skip email support and adjust here to fit
         self.logger.forgive_errors()
-        with self.assertLogs(level='WARNING') as log_capture:
+        with self.assertLogs(level="WARNING") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
+            )
+        self.assertTrue(
+            any("ID collision" in msg for msg in log_capture.output)
             )
-        self.assertTrue(any('ID collision' in msg
-                            for msg in log_capture.output))
         # TODO: enable check for removed req once skip email allows it
         # self.assertFalse(os.path.exists(req_path),
         #                  "Failed cleanup collision for %s" % req_path)
@@ -619,20 +677,20 @@ def test_manage_single_req_auth_change(self):
     def test_manage_single_req_auth_change(self):
         """Test request handling with auth password change"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
             # NOTE: we need original email here to match provisioned user
-            'email': TEST_USER_EMAIL,
-            'password': '',
-            'password_hash': TEST_MODERN_PW_PBKDF2,
-            'expire': time.time() + SECS_PER_DAY,
+            "email": TEST_USER_EMAIL,
+            "password": "",
+            "password_hash": TEST_MODERN_PW_PBKDF2,
+            "expire": time.time() + SECS_PER_DAY,
         }
         # Change password_hash here to mimic pw change
-        req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2
-        req_dict['authorized'] = True
+        req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2
+        req_dict["authorized"] = True
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
 
@@ -640,19 +698,20 @@ def test_manage_single_req_auth_change(self):
         # We forgive those errors here and only check any known warnings.
         # TODO: integrate generic skip email support and adjust here to fit
         self.logger.forgive_errors()
-        with self.assertLogs(level='INFO') as log_capture:
+        with self.assertLogs(level="INFO") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
             )
-        self.assertTrue(
-            any('accepted' in msg for msg in log_capture.output))
-        self.assertFalse(os.path.exists(req_path),
-                         "Failed to clean token req for %s" % req_path)
+        self.assertTrue(any("accepted" in msg for msg in log_capture.output))
+        self.assertFalse(
+            os.path.exists(req_path),
+            "Failed to clean token req for %s" % req_path,
+        )
 
     def test_handle_cache_updates_stub(self):
         """Test handle_cache_updates placeholder returns zero"""
@@ -662,7 +721,7 @@ def test_janitor_update_timestamps(self):
     def test_janitor_update_timestamps(self):
         """Test task trigger timestamp updates in janitor"""
         now = time.time()
-        task = 'test-task'
+        task = "test-task"
 
         # Initial state
         stamp = _lookup_last_run(self.configuration, task)
@@ -678,24 +737,24 @@ def test__clean_stale_state_files_edge(self):
     def test__clean_stale_state_files_edge(self):
         """Test state file cleaner with special cases"""
-        test_dir = self.temppath('edge_case_test', ensure_dir=True)
+        test_dir = self.temppath("edge_case_test", ensure_dir=True)
 
         # Dot file
-        dot_path = os.path.join(test_dir, '.hidden.tmp')
+        dot_path = os.path.join(test_dir, ".hidden.tmp")
         stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1
         self._prepare_test_file(dot_path, (stamp, stamp))
 
         # Directory
-        dir_path = os.path.join(test_dir, 'subdir')
+        dir_path = os.path.join(test_dir, "subdir")
         os.makedirs(dir_path)
 
         handled = _clean_stale_state_files(
             self.configuration,
             test_dir,
-            ['*'],
+            ["*"],
             EXPIRE_STATE_DAYS,
             time.time(),
-            include_dotfiles=False
+            include_dotfiles=False,
         )
         self.assertEqual(handled, 0)
 
@@ -703,46 +762,50 @@ def test__clean_stale_state_files_edge(self):
         handled = _clean_stale_state_files(
             self.configuration,
             test_dir,
-            ['*'],
+            ["*"],
             EXPIRE_STATE_DAYS,
             time.time(),
-            include_dotfiles=True
+            include_dotfiles=True,
         )
         self.assertEqual(handled, 1)
 
     @unittest.skip("TODO: enable once unpickling error handling is improved")
     def test_manage_single_req_corrupted_file(self):
         """Test manage_single_req with corrupted request file"""
-        req_id = 'corrupted_req'
+        req_id = "corrupted_req"
         req_path = os.path.join(self.configuration.user_pending, req_id)
-        with open(req_path, 'w') as fp:
-            fp.write('invalid pickle content')
+        with open(req_path, "w") as fp:
+            fp.write("invalid pickle content")
 
-        with self.assertLogs(level='ERROR') as log_capture:
+        with self.assertLogs(level="ERROR") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
            )
-        self.assertTrue(any('Failed to load request from' in msg
-                            or 'Could not load saved request' in msg
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any(
+                "Failed to load request from" in msg
+                or "Could not load saved request" in msg
+                for msg in log_capture.output
+            )
+        )
         self.assertFalse(os.path.exists(req_path))
 
     def test_manage_single_req_nonexistent_userdb(self):
         """Test manage_single_req with missing user database"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
             # NOTE: we need original email here to match provisioned user
-            'email': TEST_USER_EMAIL,
+            "email": TEST_USER_EMAIL,
         }
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
@@ -750,38 +813,39 @@ def test_manage_single_req_nonexistent_userdb(self):
 
         # Remove user database
         os.remove(self.user_db_path)
-        with self.assertLogs(level='ERROR') as log_capture:
+        with self.assertLogs(level="ERROR") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
             )
-        self.assertTrue(any('Failed to load user DB' in msg
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any("Failed to load user DB" in msg for msg in log_capture.output)
+        )
 
     def test_verify_reset_token_failure_logging(self):
         """Test token verification failure creates proper log entries"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
             # NOTE: we need original email here to match provisioned user
-            'email': TEST_USER_EMAIL,
-            'password_hash': TEST_MODERN_PW_PBKDF2,
-            'expire': time.time() + SECS_PER_DAY,  # Future expiration
+            "email": TEST_USER_EMAIL,
+            "password_hash": TEST_MODERN_PW_PBKDF2,
+            "expire": time.time() + SECS_PER_DAY,  # Future expiration
         }
         timestamp = time.time()
         # Now change to another pw hash and generate invalid token from it
-        req_dict['password_hash'] = TEST_INVALID_PW_PBKDF2
-        req_dict['reset_token'] = generate_reset_token(self.configuration,
-                                                       req_dict, TEST_SERVICE,
-                                                       timestamp)
+        req_dict["password_hash"] = TEST_INVALID_PW_PBKDF2
+        req_dict["reset_token"] = generate_reset_token(
+            self.configuration, req_dict, TEST_SERVICE, timestamp
+        )
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
 
@@ -790,17 +854,18 @@ def test_verify_reset_token_failure_logging(self):
         # We forgive those errors here and only check any known warnings.
         # TODO: integrate generic skip email support and adjust here to fit
         self.logger.forgive_errors()
-        with self.assertLogs(level='WARNING') as log_capture:
+        with self.assertLogs(level="WARNING") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
             )
-        self.assertTrue(any('wrong hash' in msg.lower()
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any("wrong hash" in msg.lower() for msg in log_capture.output)
+        )
         # TODO: enable check for removed req once skip email allows it
         # self.assertFalse(os.path.exists(req_path),
         #                  "Failed cleanup invalid token for %s" % req_path)
@@ -808,23 +873,24 @@ def test_verify_reset_token_success(self):
     def test_verify_reset_token_success(self):
         """Test token verification success with valid token"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
-            'email': TEST_USER_EMAIL,
-            'password': '',
-            'password_hash': TEST_MODERN_PW_PBKDF2,
-            'expire': time.time() + SECS_PER_DAY,  # Future expiration
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
+            "email": TEST_USER_EMAIL,
+            "password": "",
+            "password_hash": TEST_MODERN_PW_PBKDF2,
+            "expire": time.time() + SECS_PER_DAY,  # Future expiration
         }
         timestamp = time.time()
-        reset_token = generate_reset_token(self.configuration, req_dict,
-                                           TEST_SERVICE, timestamp)
-        req_dict['reset_token'] = reset_token
+        reset_token = generate_reset_token(
+            self.configuration, req_dict, TEST_SERVICE, timestamp
+        )
+        req_dict["reset_token"] = reset_token
         # Change password_hash here to mimic pw change
-        req_dict['password_hash'] = TEST_NEW_MODERN_PW_PBKDF2
+        req_dict["password_hash"] = TEST_NEW_MODERN_PW_PBKDF2
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
 
@@ -832,17 +898,18 @@ def test_verify_reset_token_success(self):
         # We forgive those errors here and only check any known warnings.
         # TODO: integrate generic skip email support and adjust here to fit
         self.logger.forgive_errors()
-        with self.assertLogs(level='INFO') as log_capture:
+        with self.assertLogs(level="INFO") as log_capture:
             manage_single_req(
                 self.configuration,
                 req_id,
                 req_path,
                 self.user_db_path,
-                time.time()
+                time.time(),
            )
-        self.assertTrue(any('accepted' in msg.lower()
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any("accepted" in msg.lower() for msg in log_capture.output)
+        )
         # TODO: enable check for removed req once skip email allows it
         # self.assertFalse(os.path.exists(req_path),
         #                  "Failed cleanup invalid token for %s" % req_path)
@@ -851,23 +918,22 @@ def test_remind_and_expire_edge_cases(self):
         """Test request expiration with exact boundary timestamps"""
         now = time.time()
         test_cases = [
-            ('exact_remind', now - REMIND_REQ_DAYS * SECS_PER_DAY),
-            ('exact_expire', now - EXPIRE_REQ_DAYS * SECS_PER_DAY),
+            ("exact_remind", now - REMIND_REQ_DAYS * SECS_PER_DAY),
+            ("exact_expire", now - EXPIRE_REQ_DAYS * SECS_PER_DAY),
         ]
-        for (req_id, mtime) in test_cases:
+        for req_id, mtime in test_cases:
             req_path = os.path.join(self.configuration.user_pending, req_id)
             req_dict = {
-                'client_id': TEST_USER_DN,
-                'distinguished_name': TEST_USER_DN,
-                'auth': [TEST_AUTH],
-                'full_name': TEST_USER_FULLNAME,
-                'organization': TEST_USER_ORG,
-                'password': TEST_MODERN_PW,
-                'email': TEST_USER_EMAIL,
+                "client_id": TEST_USER_DN,
+                "distinguished_name": TEST_USER_DN,
+                "auth": [TEST_AUTH],
+                "full_name": TEST_USER_FULLNAME,
+                "organization": TEST_USER_ORG,
+                "password": TEST_MODERN_PW,
+                "email": TEST_USER_EMAIL,
             }
-            saved, req_path = save_account_request(
-                self.configuration, req_dict)
+            saved, req_path = save_account_request(self.configuration, req_dict)
             os.utime(req_path, (mtime, mtime))
 
         # NOTE: when using real user mail we currently hit send email errors.
@@ -884,29 +950,43 @@ def test_handle_janitor_tasks_time_thresholds(self):
         """Test janitor task frequency thresholds"""
         now = time.time()
 
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "state-cleanup"), -1)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "session-cleanup"), -1)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "pending-reqs"), -1)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "cache-updates"), -1)
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "state-cleanup"), -1
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "session-cleanup"), -1
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "pending-reqs"), -1
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "cache-updates"), -1
+        )
 
         # Test all tasks EXCEPT cache-updates are past threshold
         last_state_cleanup = now - SECS_PER_DAY - 3
         last_session_cleanup = now - SECS_PER_HOUR - 3
         last_pending_reqs = now - SECS_PER_MINUTE - 3
         last_cache_update = now - SECS_PER_MINUTE + 10  # Not expired
-        task_triggers.update({'state-cleanup': last_state_cleanup,
-                              'session-cleanup': last_session_cleanup,
-                              'pending-reqs': last_pending_reqs,
-                              'cache-updates': last_cache_update})
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "state-cleanup"), last_state_cleanup)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "session-cleanup"), last_session_cleanup)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "cache-updates"), last_cache_update)
+        task_triggers.update(
+            {
+                "state-cleanup": last_state_cleanup,
+                "session-cleanup": last_session_cleanup,
+                "pending-reqs": last_pending_reqs,
+                "cache-updates": last_cache_update,
+            }
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "state-cleanup"),
+            last_state_cleanup,
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "session-cleanup"),
+            last_session_cleanup,
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "cache-updates"),
+            last_cache_update,
+        )
 
         # TODO: handled does NOT count no action runs - add dummies to handle?
         handled = handle_janitor_tasks(self.configuration, now=now)
@@ -914,26 +994,32 @@ def test_handle_janitor_tasks_time_thresholds(self):
         self.assertEqual(handled, 0)  # ran with nothing to do
 
         # Verify last run timestamps updated
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "state-cleanup"), now)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "session-cleanup"), now)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "pending-reqs"), now)
-        self.assertEqual(_lookup_last_run(
-            self.configuration, "cache-updates"), last_cache_update)
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "state-cleanup"), now
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "session-cleanup"), now
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "pending-reqs"), now
+        )
+        self.assertEqual(
+            _lookup_last_run(self.configuration, "cache-updates"),
+            last_cache_update,
+        )
 
     @unittest.skip("TODO: enable once cleaner has improved error handling")
     def test_clean_stale_files_nonexistent_dir(self):
         """Test state cleaner with invalid directory path"""
-        target_dir = os.path.join(self.configuration.mig_system_files,
-                                  "non_existing_dir")
+        target_dir = os.path.join(
+            self.configuration.mig_system_files, "non_existing_dir"
+        )
         handled = _clean_stale_state_files(
             self.configuration,
             target_dir,
             ["*"],
             EXPIRE_STATE_DAYS,
-            time.time()
+            time.time(),
         )
         self.assertEqual(handled, 0)
 
@@ -947,13 +1033,13 @@ def test_clean_stale_files_permission_error(self):
         stamp = time.time() - EXPIRE_STATE_DAYS * SECS_PER_DAY - 1
         self._prepare_test_file(test_path, (stamp, stamp))
 
-        with self.assertLogs(level='ERROR'):
+        with self.assertLogs(level="ERROR"):
             handled = _clean_stale_state_files(
                 self.configuration,
                 test_dir,
                 ["*"],
                 EXPIRE_STATE_DAYS,
-                time.time()
+                time.time(),
             )
         self.assertEqual(handled, 0)
 
@@ -976,14 +1062,14 @@ def test_janitor_task_cleanup_after_reject(self):
     def test_janitor_task_cleanup_after_reject(self):
         """Verify proper cleanup after request rejection"""
         req_dict = {
-            'client_id': TEST_USER_DN,
-            'distinguished_name': TEST_USER_DN,
-            'invalid': ['Test intentional invalid'],
-            'auth': [TEST_AUTH],
-            'full_name': TEST_USER_FULLNAME,
-            'organization': TEST_USER_ORG,
+            "client_id": TEST_USER_DN,
+            "distinguished_name": TEST_USER_DN,
+            "invalid": ["Test intentional invalid"],
+            "auth": [TEST_AUTH],
+            "full_name": TEST_USER_FULLNAME,
+            "organization": TEST_USER_ORG,
             # NOTE: we need original email here to match provisioned user
-            'email': TEST_USER_EMAIL,
+            "email": TEST_USER_EMAIL,
         }
         saved, req_path = save_account_request(self.configuration, req_dict)
         req_id = os.path.basename(req_path)
@@ -996,11 +1082,7 @@ def test_janitor_task_cleanup_after_reject(self):
         # TODO: integrate generic skip email support and adjust here to fit
         self.logger.forgive_errors()
         manage_single_req(
-            self.configuration,
-            req_id,
-            req_path,
-            self.user_db_path,
-            time.time()
+            self.configuration, req_id, req_path, self.user_db_path, time.time()
         )
 
         # TODO: enable check for removed req once skip email allows it
@@ -1009,15 +1091,15 @@ def test_cleaner_with_multiple_patterns(self):
     def test_cleaner_with_multiple_patterns(self):
         """Test state cleaner with multiple filename patterns"""
-        test_dir = self.temppath('multi_pattern_test', ensure_dir=True)
-        clean_patterns = ['*.tmp', '*.log', 'temp*']
+        test_dir = self.temppath("multi_pattern_test", ensure_dir=True)
+        clean_patterns = ["*.tmp", "*.log", "temp*"]
 
         clean_pairs = [
-            ('should_keep_recent.log', EXPIRE_STATE_DAYS - 1),
('should_remove_stale.tmp', EXPIRE_STATE_DAYS + 1), - ('should_keep_other.pck', EXPIRE_STATE_DAYS + 1) + ("should_keep_recent.log", EXPIRE_STATE_DAYS - 1), + ("should_remove_stale.tmp", EXPIRE_STATE_DAYS + 1), + ("should_keep_other.pck", EXPIRE_STATE_DAYS + 1), ] - for (name, age_days) in clean_pairs: + for name, age_days in clean_pairs: path = os.path.join(test_dir, name) stamp = time.time() - age_days * SECS_PER_DAY self._prepare_test_file(path, (stamp, stamp)) @@ -1028,15 +1110,18 @@ def test_cleaner_with_multiple_patterns(self): test_dir, clean_patterns, EXPIRE_STATE_DAYS, - time.time() + time.time(), ) self.assertEqual(handled, 1) - self.assertTrue(os.path.exists( - os.path.join(test_dir, 'should_keep_recent.log'))) - self.assertFalse(os.path.exists( - os.path.join(test_dir, 'should_remove_stale.tmp'))) - self.assertTrue(os.path.exists( - os.path.join(test_dir, 'should_keep_other.pck'))) + self.assertTrue( + os.path.exists(os.path.join(test_dir, "should_keep_recent.log")) + ) + self.assertFalse( + os.path.exists(os.path.join(test_dir, "should_remove_stale.tmp")) + ) + self.assertTrue( + os.path.exists(os.path.join(test_dir, "should_keep_other.pck")) + ) def test_absent_jobs_flag(self): """Test clean_no_job_helpers with site_enable_jobs disabled""" diff --git a/tests/test_mig_lib_quota.py b/tests/test_mig_lib_quota.py index e4057b5d1..c99b0a920 100644 --- a/tests/test_mig_lib_quota.py +++ b/tests/test_mig_lib_quota.py @@ -29,6 +29,7 @@ # Imports of the code under test from mig.lib.quota import update_quota + # Imports required for the unit tests themselves from tests.support import MigTestCase @@ -38,7 +39,7 @@ class MigLibQouta(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" def before_each(self): """Set up test configuration and reset state before each test""" @@ -47,7 +48,9 @@ def before_each(self): def test_invalid_quota_backend(self): """Test invalid quota_backend in configuration""" self.configuration.quota_backend = "NEVERNEVER" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: update_quota(self.configuration) - self.assertTrue("'NEVERNEVER' not in supported_quota_backends:" in msg - for msg in log_capture.output) + self.assertTrue( + any( + "'NEVERNEVER' not in supported_quota_backends:" in msg + for msg in log_capture.output + ) + ) diff --git a/tests/test_mig_lib_xgicore.py b/tests/test_mig_lib_xgicore.py index e63e22b7d..588b487bd 100644 --- a/tests/test_mig_lib_xgicore.py +++ b/tests/test_mig_lib_xgicore.py @@ -30,9 +30,8 @@ import os import sys -from tests.support import MigTestCase, FakeConfiguration, testmain - from mig.lib.xgicore import * +from tests.support import FakeConfiguration, MigTestCase, testmain class MigLibXgicore__get_output_format(MigTestCase): @@ -42,28 +41,32 @@ def test_default_when_missing(self): """Test that default output_format is returned when not set.""" expected = "html" user_args = {} - actual = get_output_format(FakeConfiguration(), user_args, - default_format=expected) - self.assertEqual(actual, expected, - "mismatch in default output_format") + actual = get_output_format( + FakeConfiguration(), user_args, default_format=expected + ) + self.assertEqual(actual, expected, "mismatch in default output_format") def test_get_single_requested_format(self): """Test that the requested output_format is returned.""" expected = "file" - user_args = {'output_format': [expected]} - actual = get_output_format(FakeConfiguration(), 
user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) def test_get_first_requested_format(self): """Test that first requested output_format is returned.""" expected = "file" - user_args = {'output_format': [expected, 'BOGUS']} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected, "BOGUS"]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) class MigLibXgicore__override_output_format(MigTestCase): @@ -74,29 +77,35 @@ def test_unchanged_without_override(self): expected = "html" user_args = {} out_objs = [] - actual = override_output_format(FakeConfiguration(), user_args, - out_objs, expected) - self.assertEqual(actual, expected, - "mismatch in unchanged output_format") + actual = override_output_format( + FakeConfiguration(), user_args, out_objs, expected + ) + self.assertEqual( + actual, expected, "mismatch in unchanged output_format" + ) def test_get_single_requested_format(self): """Test that the requested output_format is returned if overriden.""" expected = "file" - user_args = {'output_format': [expected]} - out_objs = [{'object_type': 'start', 'override_format': True}] - actual = override_output_format(FakeConfiguration(), user_args, - out_objs, 'OVERRIDE') - self.assertEqual(actual, expected, - "mismatch in overriden output_format") + user_args = {"output_format": [expected]} + out_objs = [{"object_type": "start", "override_format": True}] + actual = override_output_format( + FakeConfiguration(), user_args, out_objs, "OVERRIDE" + ) + self.assertEqual( + actual, expected, "mismatch in overriden output_format" + ) def test_get_first_requested_format(self): """Test that first requested output_format is returned if overriden.""" expected = "file" - user_args = {'output_format': [expected, 'BOGUS']} - actual = get_output_format(FakeConfiguration(), user_args, - default_format='BOGUS') - self.assertEqual(actual, expected, - "mismatch in extracted output_format") + user_args = {"output_format": [expected, "BOGUS"]} + actual = get_output_format( + FakeConfiguration(), user_args, default_format="BOGUS" + ) + self.assertEqual( + actual, expected, "mismatch in extracted output_format" + ) class MigLibXgicore__fill_start_headers(MigTestCase): @@ -105,35 +114,38 @@ class MigLibXgicore__fill_start_headers(MigTestCase): def test_unchanged_when_set(self): """Test that existing valid start entry is returned as-is.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream'), - ('Content-Size', 42)] - expected = {'object_type': 'start', 'headers': headers} - out_objs = [expected, {'object_type': 'binary', 'data': 42*b'0'}] + headers = [ + ("Content-Type", "application/octet-stream"), + ("Content-Size", 42), + ] + expected = {"object_type": "start", "headers": headers} + out_objs = [expected, {"object_type": "binary", "data": 42 * b"0"}] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in unchanged start entry") + self.assertEqual(actual, expected, "mismatch in 
unchanged start entry") def test_headers_added_when_missing(self): """Test that start entry headers are added if missing.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream')] - minimal_start = {'object_type': 'start'} - expected = {'object_type': 'start', 'headers': headers} - out_objs = [minimal_start, {'object_type': 'binary', 'data': 42*b'0'}] + headers = [("Content-Type", "application/octet-stream")] + minimal_start = {"object_type": "start"} + expected = {"object_type": "start", "headers": headers} + out_objs = [minimal_start, {"object_type": "binary", "data": 42 * b"0"}] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in auto initialized start entry") + self.assertEqual( + actual, expected, "mismatch in auto initialized start entry" + ) def test_start_added_when_missing(self): """Test that start entry is added if missing.""" out_format = "file" - headers = [('Content-Type', 'application/octet-stream')] - expected = {'object_type': 'start', 'headers': headers} - out_objs = [{'object_type': 'binary', 'data': 42*b'0'}] + headers = [("Content-Type", "application/octet-stream")] + expected = {"object_type": "start", "headers": headers} + out_objs = [{"object_type": "binary", "data": 42 * b"0"}] actual = fill_start_headers(FakeConfiguration(), out_objs, out_format) - self.assertEqual(actual, expected, - "mismatch in auto initialized start entry") + self.assertEqual( + actual, expected, "mismatch in auto initialized start entry" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_accountreq.py b/tests/test_mig_shared_accountreq.py index ecba42ff3..0b7a6919d 100644 --- a/tests/test_mig_shared_accountreq.py +++ b/tests/test_mig_shared_accountreq.py @@ -35,9 +35,11 @@ # Imports of the code under test import mig.shared.accountreq as accountreq + # Imports required for the unit test wrapping from mig.shared.base import distinguished_name_to_user, fill_distinguished_name from mig.shared.defaults import keyword_auto + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain from tests.support.fixturesupp import FixtureAssertMixin @@ -47,8 +49,8 @@ class MigSharedAccountreq__peers(MigTestCase, FixtureAssertMixin): """Unit tests for peers related functions within the accountreq module""" - TEST_PEER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com' - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' + TEST_PEER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=peer@example.com" + TEST_USER_DN = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" @property def user_settings_dir(self): @@ -60,40 +62,46 @@ def user_pending_dir(self): def _load_saved_peer(self, absolute_path): self.assertPathWithin(absolute_path, start=self.user_pending_dir) - with open(absolute_path, 'rb') as pickle_file: + with open(absolute_path, "rb") as pickle_file: value = pickle.load(pickle_file) def _string_if_bytes(value): if isinstance(value, bytes): - return str(value, 'utf8') + return str(value, "utf8") else: return value - return {_string_if_bytes(x): _string_if_bytes(y) - for x, y in value.items()} + + return { + _string_if_bytes(x): _string_if_bytes(y) for x, y in value.items() + } def _peer_dict_from_fixture(self): - prepared_fixture = self.prepareFixtureAssert("peer_user_dict", - 
fixture_format="json") + prepared_fixture = self.prepareFixtureAssert( + "peer_user_dict", fixture_format="json" + ) fixture_data = prepared_fixture.fixture_data assert fixture_data["distinguished_name"] == self.TEST_PEER_DN return fixture_data - def _record_peer_acceptance(self, test_client_dir_name, - peer_distinguished_name): - """Fabricate a peer acceptance record in a particular user settings dir. - """ + def _record_peer_acceptance( + self, test_client_dir_name, peer_distinguished_name + ): + """Fabricate a peer acceptance record in a particular user settings dir.""" test_user_accepted_peers_file = os.path.join( - self.user_settings_dir, test_client_dir_name, "peers") + self.user_settings_dir, test_client_dir_name, "peers" + ) expire_tomorrow = datetime.date.today() + datetime.timedelta(days=1) - with open(test_user_accepted_peers_file, "wb") as \ - test_user_accepted_peers: - pickle.dump({peer_distinguished_name: - {'expire': str(expire_tomorrow)}}, - test_user_accepted_peers) + with open( + test_user_accepted_peers_file, "wb" + ) as test_user_accepted_peers: + pickle.dump( + {peer_distinguished_name: {"expire": str(expire_tomorrow)}}, + test_user_accepted_peers, + ) def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): ensure_dirs_exist(self.configuration.user_cache) @@ -108,8 +116,9 @@ def test_a_new_peer(self): self.assertDirEmpty(self.configuration.user_pending) request_dict = self._peer_dict_from_fixture() - success, _ = accountreq.save_account_request(self.configuration, - request_dict) + success, _ = accountreq.save_account_request( + self.configuration, request_dict + ) # check that we have an output directory now absolute_files = self.assertDirNotEmpty(self.user_pending_dir) @@ -131,10 +140,11 @@ def test_listing_peers(self): # check the fabricated peer was listed # sadly listing returns _relative_ dirs peer_temp_file_name = listing[0] - peer_pickle_file = os.path.join(self.user_pending_dir, - peer_temp_file_name) + peer_pickle_file = os.path.join( + self.user_pending_dir, peer_temp_file_name + ) peer_pickle = self._load_saved_peer(peer_pickle_file) - self.assertEqual(peer_pickle['distinguished_name'], self.TEST_PEER_DN) + self.assertEqual(peer_pickle["distinguished_name"], self.TEST_PEER_DN) def test_peer_acceptance(self): test_client_dir = self._provision_test_user(self, self.TEST_USER_DN) @@ -142,17 +152,18 @@ def test_peer_acceptance(self): self._record_peer_acceptance(test_client_dir_name, self.TEST_PEER_DN) self.assertDirEmpty(self.user_pending_dir) request_dict = self._peer_dict_from_fixture() - success, req_path = accountreq.save_account_request(self.configuration, - request_dict) + success, req_path = accountreq.save_account_request( + self.configuration, request_dict + ) arranged_req_id = os.path.basename(req_path) # NOTE: when using real user mail we currently hit send email errors. # We forgive those errors here and only check any known warnings. 
# TODO: integrate generic skip email support and adjust here to fit self.logger.forgive_errors() - success, message = accountreq.accept_account_req(arranged_req_id, - self.configuration, - keyword_auto) + success, message = accountreq.accept_account_req( + arranged_req_id, self.configuration, keyword_auto + ) self.assertTrue(success) @@ -160,29 +171,44 @@ def test_peer_acceptance(self): class MigSharedAccountreq__filters(MigTestCase, UserAssertMixin): """Unit tests for filter related functions within the accountreq module""" - TEST_SERVICE = 'migoid' - TEST_INTERNAL_DN = '/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=Test Name/emailAddress=test@local.org' - TEST_EXTERNAL_DN = '/C=DK/ST=NA/L=NA/O=External Org/OU=NA/CN=Test User/emailAddress=test@external.org' - TEST_USER_DN = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' - TEST_ADMIN_DN = '/C=DK/ST=NA/L=NA/O=DIKU/OU=NA/CN=Test Admin/emailAddress=siteadm@di.ku.dk' - - TEST_INT_PW = 'PW74deb6609F109f504d' - TEST_EXT_PW = 'PW174db6509F109e1531' - TEST_USER_PW = 'foobar' - TEST_INT_PW_HASH = 'PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$epib2rEg/HYTQZFnCp7hmIGZ6rzHnViy' - TEST_EXT_PW_HASH = 'PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$TQZFnCp7hmIGZ6ep2rEg/HYrzHnVyiib' - TEST_USER_PW_HASH = 'PBKDF2$sha256$10000$/TkhLk4yMGf6XhaY$7HUeQ9iwCkE4YMQAaCd+ZdrN+y8EzkJH' - - TEST_INTERNAL_EMAILS = ['john.doe@science.ku.dk', 'abc123@ku.dk', - 'john.doe@a.b.c.ku.dk'] - TEST_EXTERNAL_EMAILS = ['john@doe.org', 'a@b.c.org', 'a@ku.dk.com', - 'a@sci.ku.dk.org', 'a@diku.dk', 'a@nbi.dk'] - TEST_EXTERNAL_EMAIL_PATTERN = r'^.+(?@') + "connection_string": "user:password@db.example.com", + "other_field": "some_value", } + subst_map = {"connection_string": (r":.*@", r":@")} masked = mask_creds(user_dict, subst_map=subst_map) - self.assertEqual(masked['connection_string'], - 'user:@db.example.com') - self.assertEqual(masked['other_field'], 'some_value') + self.assertEqual( + masked["connection_string"], "user:@db.example.com" + ) + self.assertEqual(masked["other_field"], "some_value") def test_mask_creds_no_maskable_fields(self): """Test mask_creds with a dictionary containing no maskable fields.""" - user_dict = {'username': 'test', 'role': 'user'} + user_dict = {"username": "test", "role": "user"} masked = mask_creds(user_dict) self.assertEqual(user_dict, masked) @@ -230,55 +275,58 @@ def test_mask_creds_empty_dict(self): def test_mask_creds_csrf_field(self): """Test that the default csrf_field is masked.""" - user_dict = {csrf_field: 'some_csrf_token', 'other': 'value'} + user_dict = {csrf_field: "some_csrf_token", "other": "value"} masked = mask_creds(user_dict) - self.assertEqual(masked[csrf_field], '**HIDDEN**') - self.assertEqual(masked['other'], 'value') + self.assertEqual(masked[csrf_field], "**HIDDEN**") + self.assertEqual(masked["other"], "value") def test_extract_field_exists(self): """Test extracting an existing field from a distinguished name.""" - self.assertEqual(extract_field(TEST_USER_ID, 'full_name'), - TEST_FULL_NAME) - self.assertEqual(extract_field(TEST_USER_ID, 'organization'), - TEST_ORGANIZATION) - self.assertEqual(extract_field(TEST_USER_ID, 'country'), TEST_COUNTRY) - self.assertEqual(extract_field(TEST_USER_ID, 'email'), TEST_EMAIL) + self.assertEqual( + extract_field(TEST_USER_ID, "full_name"), TEST_FULL_NAME + ) + self.assertEqual( + extract_field(TEST_USER_ID, "organization"), TEST_ORGANIZATION + ) + self.assertEqual(extract_field(TEST_USER_ID, "country"), TEST_COUNTRY) + self.assertEqual(extract_field(TEST_USER_ID, "email"), 
TEST_EMAIL) def test_extract_field_not_exists(self): """Test extracting a non-existent field returns None.""" - self.assertIsNone(extract_field(TEST_USER_ID, 'missing')) - self.assertIsNone(extract_field(TEST_USER_ID, 'dummy')) + self.assertIsNone(extract_field(TEST_USER_ID, "missing")) + self.assertIsNone(extract_field(TEST_USER_ID, "dummy")) def test_extract_field_with_na_value(self): """Test extracting a field with 'NA' value, which should be an empty string.""" - self.assertEqual(extract_field('/C=DK/DUMMY=NA/CN=TEST', 'DUMMY'), '') + self.assertEqual(extract_field("/C=DK/DUMMY=NA/CN=TEST", "DUMMY"), "") def test_extract_field_custom_field(self): """Test extracting a custom (non-standard) field.""" - self.assertEqual(extract_field('/C=DK/DUMMY=proj1/CN=Test', 'DUMMY'), - 'proj1') + self.assertEqual( + extract_field("/C=DK/DUMMY=proj1/CN=Test", "DUMMY"), "proj1" + ) def test_extract_field_empty_dn(self): """Test extracting from an empty distinguished name.""" - self.assertIsNone(extract_field("", 'full_name')) + self.assertIsNone(extract_field("", "full_name")) def test_extract_field_malformed_dn(self): """Test extracting from a malformed distinguished name.""" dn_empty_val = "/C=US/O=/CN=John Doe" - self.assertEqual(extract_field(dn_empty_val, 'organization'), '') + self.assertEqual(extract_field(dn_empty_val, "organization"), "") dn_no_equals = "/C=US/O/CN=John Doe" - self.assertIsNone(extract_field(dn_no_equals, 'organization')) + self.assertIsNone(extract_field(dn_no_equals, "organization")) def test_distinguished_name_to_user_basic(self): """Test basic conversion from distinguished name to user dictionary.""" user_dict = distinguished_name_to_user(TEST_USER_ID) expected = { - 'distinguished_name': TEST_USER_ID, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, + "distinguished_name": TEST_USER_ID, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, } self.assertEqual(user_dict, expected) @@ -287,12 +335,12 @@ def test_distinguished_name_to_user_with_na(self): dn = "%s/dummy=NA" % TEST_USER_ID user_dict = distinguished_name_to_user(dn) expected = { - 'distinguished_name': dn, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'dummy': '' + "distinguished_name": dn, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "dummy": "", } self.assertEqual(user_dict, expected) @@ -301,28 +349,29 @@ def test_distinguished_name_to_user_with_custom_field(self): dn = "%s/dummy=proj1" % TEST_USER_ID user_dict = distinguished_name_to_user(dn) expected = { - 'distinguished_name': dn, - 'country': TEST_COUNTRY, - 'organization': TEST_ORGANIZATION, - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'dummy': 'proj1' + "distinguished_name": dn, + "country": TEST_COUNTRY, + "organization": TEST_ORGANIZATION, + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "dummy": "proj1", } self.assertEqual(user_dict, expected) def test_distinguished_name_to_user_empty_and_malformed(self): """Test behavior with empty and malformed distinguished names.""" # Empty DN - self.assertEqual(distinguished_name_to_user(""), - {'distinguished_name': ''}) + self.assertEqual( + distinguished_name_to_user(""), {"distinguished_name": ""} + ) # Malformed part (no '=') dn_malformed = "/C=US/O/CN=John Doe" user_dict_malformed = 
distinguished_name_to_user(dn_malformed) expected_malformed = { - 'distinguished_name': dn_malformed, - 'country': 'US', - 'full_name': TEST_FULL_NAME + "distinguished_name": dn_malformed, + "country": "US", + "full_name": TEST_FULL_NAME, } self.assertEqual(user_dict_malformed, expected_malformed) @@ -330,45 +379,49 @@ def test_distinguished_name_to_user_empty_and_malformed(self): dn_empty_val = "/C=DK/O=/CN=John Doe" user_dict_empty_val = distinguished_name_to_user(dn_empty_val) expected_empty_val = { - 'distinguished_name': dn_empty_val, - 'country': TEST_COUNTRY, - 'organization': '', - 'full_name': TEST_FULL_NAME + "distinguished_name": dn_empty_val, + "country": TEST_COUNTRY, + "organization": "", + "full_name": TEST_FULL_NAME, } self.assertEqual(user_dict_empty_val, expected_empty_val) def test_fill_distinguished_name_from_fields(self): """Test filling distinguished_name from other user fields.""" user = { - 'full_name': 'Jane Doe', - 'organization': 'Test Corp', - 'country': TEST_COUNTRY, - 'email': 'jane.doe@example.com' + "full_name": "Jane Doe", + "organization": "Test Corp", + "country": TEST_COUNTRY, + "email": "jane.doe@example.com", } fill_distinguished_name(user) - expected_dn = "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" \ - "/emailAddress=jane.doe@example.com" - self.assertEqual(user['distinguished_name'], expected_dn) + expected_dn = ( + "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" + "/emailAddress=jane.doe@example.com" + ) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_distinguished_name_with_gdp(self): """Test filling distinguished_name with a GDP project field.""" user = { - 'full_name': 'Jane Doe', - 'organization': 'Test Corp', - 'country': TEST_COUNTRY, - gdp_distinguished_field: 'project_x' + "full_name": "Jane Doe", + "organization": "Test Corp", + "country": TEST_COUNTRY, + gdp_distinguished_field: "project_x", } fill_distinguished_name(user) - expected_dn = "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" \ - "/emailAddress=NA/GDP=project_x" - self.assertEqual(user['distinguished_name'], expected_dn) + expected_dn = ( + "/C=DK/ST=NA/L=NA/O=Test Corp/OU=NA/CN=Jane Doe" + "/emailAddress=NA/GDP=project_x" + ) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_distinguished_name_already_exists(self): """Test that an existing distinguished_name is not overwritten.""" user = { - 'distinguished_name': TEST_USER_ID, - 'full_name': 'Jane Doe', - 'country': 'US' + "distinguished_name": TEST_USER_ID, + "full_name": "Jane Doe", + "country": "US", } original_user = user.copy() returned_user = fill_distinguished_name(user) @@ -380,24 +433,21 @@ def test_fill_distinguished_name_empty_user(self): user = {} fill_distinguished_name(user) expected_dn = "/C=NA/ST=NA/L=NA/O=NA/OU=NA/CN=NA/emailAddress=NA" - self.assertEqual(user['distinguished_name'], expected_dn) + self.assertEqual(user["distinguished_name"], expected_dn) def test_fill_user_completes_dict(self): """Test that fill_user adds missing fields and preserves existing ones.""" - user = { - 'full_name': TEST_FULL_NAME, - 'extra_field': 'extra_value' - } + user = {"full_name": TEST_FULL_NAME, "extra_field": "extra_value"} fill_user(user) # Check that existing values are preserved - self.assertEqual(user['full_name'], TEST_FULL_NAME) - self.assertEqual(user['extra_field'], 'extra_value') + self.assertEqual(user["full_name"], TEST_FULL_NAME) + self.assertEqual(user["extra_field"], "extra_value") # Check that missing standard fields are added with empty strings - 
self.assertEqual(user['organization'], '') - self.assertEqual(user['country'], '') + self.assertEqual(user["organization"], "") + self.assertEqual(user["country"], "") # Check that all standard keys are present for key, _ in cert_field_order: @@ -409,7 +459,7 @@ def test_fill_user_with_empty_dict(self): fill_user(user) self.assertEqual(len(user), len(cert_field_order)) for key, _ in cert_field_order: - self.assertEqual(user[key], '') + self.assertEqual(user[key], "") def test_fill_user_modifies_in_place_and_returns_self(self): """Test that fill_user modifies the dictionary in-place and returns @@ -421,158 +471,171 @@ def test_fill_user_modifies_in_place_and_returns_self(self): def test_canonical_user_transformations(self): """Test canonical_user applies all transformations correctly.""" user_dict = { - 'full_name': ' john doe ', - 'email': 'John@DoE.ORG', - 'country': 'dk', - 'state': 'vt', - 'organization': ' Test Org ', - 'extra_field': 'should be removed', - 'id': 123 + "full_name": " john doe ", + "email": "John@DoE.ORG", + "country": "dk", + "state": "vt", + "organization": " Test Org ", + "extra_field": "should be removed", + "id": 123, } - limit_fields = ['full_name', 'email', - 'country', 'state', 'organization', 'id'] + limit_fields = [ + "full_name", + "email", + "country", + "state", + "organization", + "id", + ] canonical = canonical_user(self.configuration, user_dict, limit_fields) expected = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'country': TEST_COUNTRY, - 'state': TEST_STATE, - 'organization': TEST_ORGANIZATION, - 'id': 123 + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "country": TEST_COUNTRY, + "state": TEST_STATE, + "organization": TEST_ORGANIZATION, + "id": 123, } self.assertEqual(canonical, expected) - self.assertNotIn('extra_field', canonical) + self.assertNotIn("extra_field", canonical) def test_canonical_user_with_peers_legacy(self): """Test canonical_user_with_peers with legacy peers list""" - self.configuration.site_peers_explicit_fields = ['email', 'full_name'] + self.configuration.site_peers_explicit_fields = ["email", "full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers": [ + "/C=DK/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_email'], - 'alice@example.com, bob@example.com') - self.assertEqual(canonical['peers_full_name'], 'Alice, Bob') + self.assertEqual( + canonical["peers_email"], "alice@example.com, bob@example.com" + ) + self.assertEqual(canonical["peers_full_name"], "Alice, Bob") def test_canonical_user_with_peers_explicit(self): """Test canonical_user_with_peers with explicit peers fields""" - self.configuration.site_peers_explicit_fields = ['email', 'full_name'] + self.configuration.site_peers_explicit_fields = ["email", "full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers_email': 'custom@example.com', - 'peers_full_name': 'Custom Name', - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": 
TEST_EMAIL, + "peers_email": "custom@example.com", + "peers_full_name": "Custom Name", + "peers": [ + "/C=DK/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_email'], 'custom@example.com') - self.assertEqual(canonical['peers_full_name'], 'Custom Name') + self.assertEqual(canonical["peers_email"], "custom@example.com") + self.assertEqual(canonical["peers_full_name"], "Custom Name") def test_canonical_user_with_peers_mixed(self): """Test canonical_user_with_peers with mixed explicit and legacy peers""" self.configuration.site_peers_explicit_fields = [ - 'email', 'organization'] + "email", + "organization", + ] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL, - 'peers_organization': TEST_ORGANIZATION, - 'peers': [ - '/C=DK/O=Legacy Org/CN=Alice/emailAddress=alice@example.com', - '/C=DK/CN=Bob/emailAddress=bob@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": TEST_EMAIL, + "peers_organization": TEST_ORGANIZATION, + "peers": [ + "/C=DK/O=Legacy Org/CN=Alice/emailAddress=alice@example.com", + "/C=DK/CN=Bob/emailAddress=bob@example.com", + ], } - limit_fields = ['full_name', 'email', 'organization'] + limit_fields = ["full_name", "email", "organization"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) # Explicit field should be preserved - self.assertEqual(canonical['peers_organization'], 'Test Org') + self.assertEqual(canonical["peers_organization"], "Test Org") # Legacy peers should be converted for email - self.assertEqual(canonical['peers_email'], - 'alice@example.com, bob@example.com') + self.assertEqual( + canonical["peers_email"], "alice@example.com, bob@example.com" + ) def test_canonical_user_with_peers_empty(self): """Test canonical_user_with_peers with no peers data""" - self.configuration.site_peers_explicit_fields = ['email'] - user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': TEST_EMAIL - } - limit_fields = ['full_name', 'email'] + self.configuration.site_peers_explicit_fields = ["email"] + user_dict = {"full_name": TEST_FULL_NAME, "email": TEST_EMAIL} + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertNotIn('peers_email', canonical) - self.assertNotIn('peers', canonical) + self.assertNotIn("peers_email", canonical) + self.assertNotIn("peers", canonical) def test_canonical_user_with_peers_no_explicit_fields(self): """Test canonical_user_with_peers with no peer fields configured""" self.configuration.site_peers_explicit_fields = [] user_dict = { - 'full_name': TEST_FULL_NAME, - 'email': 'john@example.com', - 'peers': [ - '/C=DK/CN=Alice/emailAddress=alice@example.com' - ] + "full_name": TEST_FULL_NAME, + "email": "john@example.com", + "peers": ["/C=DK/CN=Alice/emailAddress=alice@example.com"], } - limit_fields = ['full_name', 'email'] + limit_fields = ["full_name", "email"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) # Should not create any peer fields - self.assertNotIn('peers_email', canonical) - self.assertNotIn('peers_full_name', canonical) 
+ self.assertNotIn("peers_email", canonical) + self.assertNotIn("peers_full_name", canonical) def test_canonical_user_with_peers_special_chars(self): """Test canonical_user_with_peers handles special characters in DNs""" - self.configuration.site_peers_explicit_fields = ['full_name'] + self.configuration.site_peers_explicit_fields = ["full_name"] user_dict = { - 'full_name': TEST_FULL_NAME, - 'peers': [ - '/C=DK/CN=Jérôme Müller', - '/C=DK/CN=O‘‘Reilly', - '/C=DK/CN=Alice "Ace" Smith' - ] + "full_name": TEST_FULL_NAME, + "peers": [ + "/C=DK/CN=Jérôme Müller", + "/C=DK/CN=O‘‘Reilly", + '/C=DK/CN=Alice "Ace" Smith', + ], } - limit_fields = ['full_name'] + limit_fields = ["full_name"] canonical = canonical_user_with_peers( - self.configuration, user_dict, limit_fields) + self.configuration, user_dict, limit_fields + ) - self.assertEqual(canonical['peers_full_name'], - 'Jérôme Müller, O‘‘Reilly, Alice "Ace" Smith') + self.assertEqual( + canonical["peers_full_name"], + 'Jérôme Müller, O‘‘Reilly, Alice "Ace" Smith', + ) def test_canonical_user_unicode_name(self): """Test canonical_user with unicode characters in full_name.""" # Using a name that title() might mess up without unicode conversion - user_dict = {'full_name': u'josé de la vega'} - limit_fields = ['full_name'] + user_dict = {"full_name": "josé de la vega"} + limit_fields = ["full_name"] canonical = canonical_user(self.configuration, user_dict, limit_fields) - self.assertEqual(canonical['full_name'], u'José De La Vega') + self.assertEqual(canonical["full_name"], "José De La Vega") def test_canonical_user_empty_input(self): """Test canonical_user with empty inputs.""" self.assertEqual(canonical_user(self.configuration, {}, []), {}) - self.assertEqual(canonical_user(self.configuration, {'a': 1}, []), {}) - self.assertEqual(canonical_user(self.configuration, {}, ['a']), {}) + self.assertEqual(canonical_user(self.configuration, {"a": 1}, []), {}) + self.assertEqual(canonical_user(self.configuration, {}, ["a"]), {}) def test_generate_https_urls_single_method_cgi(self): """Test generate_https_urls with a single method and cgi-bin.""" - self.configuration.site_login_methods = ['migcert'] + self.configuration.site_login_methods = ["migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" expected = "https://mig.cert/cgi-bin/script.py" result = generate_https_urls(self.configuration, template, {}) @@ -581,7 +644,7 @@ def test_generate_https_urls_single_method_cgi(self): def test_generate_https_urls_single_method_wsgi(self): """Test generate_https_urls with a single method and wsgi-bin.""" self.configuration.site_enable_wsgi = True - self.configuration.site_login_methods = ['migcert'] + self.configuration.site_login_methods = ["migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" expected = "https://mig.cert/wsgi-bin/script.py" result = generate_https_urls(self.configuration, template, {}) @@ -590,29 +653,32 @@ def test_generate_https_urls_single_method_wsgi(self): def test_generate_https_urls_multiple_methods(self): """Test generate_https_urls with multiple methods.""" template = "%(auto_base)s/%(auto_bin)s/script.py" - self.configuration.site_login_methods = ['migcert', 'extoidc'] + self.configuration.site_login_methods = ["migcert", "extoidc"] result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://mig.cert/cgi-bin/script.py" expected_url2 = "https://ext.oidc/cgi-bin/script.py" expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks 
most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_generate_https_urls_with_helper_dict(self): """Test generate_https_urls with a helper_dict.""" - self.configuration.site_login_methods = ['extoid'] + self.configuration.site_login_methods = ["extoid"] template = "%(auto_base)s/%(auto_bin)s/%(script)s" - helper = {'script': 'login.py'} + helper = {"script": "login.py"} result = generate_https_urls(self.configuration, template, helper) self.assertEqual(result, "https://ext.oid/cgi-bin/login.py") def test_generate_https_urls_method_enabled_but_url_missing(self): """Test that methods with no configured URL are skipped.""" self.configuration.migserver_https_ext_cert_url = "" # URL is missing - self.configuration.site_login_methods = ['migcert', 'extcert'] + self.configuration.site_login_methods = ["migcert", "extcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) self.assertEqual(result, "https://mig.cert/cgi-bin/script.py") @@ -626,7 +692,7 @@ def test_generate_https_urls_no_methods_enabled(self): def test_generate_https_urls_respects_order(self): """Test that the order of site_login_methods is respected.""" - self.configuration.site_login_methods = ['extoidc', 'migcert'] + self.configuration.site_login_methods = ["extoidc", "migcert"] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://ext.oidc/cgi-bin/script.py" @@ -634,14 +700,20 @@ def test_generate_https_urls_respects_order(self): expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_generate_https_urls_avoids_duplicates(self): """Test that duplicate URLs are not generated.""" self.configuration.site_login_methods = [ - 'migcert', 'extoidc', 'migcert'] + "migcert", + "extoidc", + "migcert", + ] template = "%(auto_base)s/%(auto_bin)s/script.py" result = generate_https_urls(self.configuration, template, {}) expected_url1 = "https://mig.cert/cgi-bin/script.py" @@ -649,22 +721,35 @@ def test_generate_https_urls_avoids_duplicates(self): expected_note = """ (The URL depends on whether you log in with OpenID or a user certificate - just use the one that looks most familiar or try them in turn)""" - expected_result = "%s\nor\n%s%s" % (expected_url1, expected_url2, - expected_note) + expected_result = "%s\nor\n%s%s" % ( + expected_url1, + expected_url2, + expected_note, + ) self.assertEqual(result, expected_result) def test_auth_type_description_all(self): """Test auth_type_description returns full dict when requested""" from mig.shared.defaults import keyword_all + result = auth_type_description(self.configuration, keyword_all) - expected_keys = ['migoid', 'migoidc', 'migcert', 'extoid', 'extoidc', - 'extcert'] + expected_keys = [ + "migoid", + "migoidc", + "migcert", + "extoid", + "extoidc", + "extcert", + ] self.assertEqual(sorted(result.keys()), sorted(expected_keys)) def test_auth_type_description_individual(self): """Test auth_type_description 
returns expected strings for each type""" - from mig.shared.defaults import AUTH_CERTIFICATE, AUTH_OPENID_CONNECT, \ - AUTH_OPENID_V2 + from mig.shared.defaults import ( + AUTH_CERTIFICATE, + AUTH_OPENID_CONNECT, + AUTH_OPENID_V2, + ) # Setup titles in configuration self.configuration.user_mig_oid_title = "MiG OpenID" @@ -673,31 +758,37 @@ def test_auth_type_description_individual(self): self.configuration.user_ext_cert_title = "External Certificate" test_cases = [ - ('migoid', 'MiG OpenID %s' % AUTH_OPENID_V2), - ('migoidc', 'MiG OpenID %s' % AUTH_OPENID_CONNECT), - ('migcert', 'MiG Certificate %s' % AUTH_CERTIFICATE), - ('extoid', 'External OpenID %s' % AUTH_OPENID_V2), - ('extoidc', 'External OpenID %s' % AUTH_OPENID_CONNECT), - ('extcert', 'External Certificate %s' % AUTH_CERTIFICATE), + ("migoid", "MiG OpenID %s" % AUTH_OPENID_V2), + ("migoidc", "MiG OpenID %s" % AUTH_OPENID_CONNECT), + ("migcert", "MiG Certificate %s" % AUTH_CERTIFICATE), + ("extoid", "External OpenID %s" % AUTH_OPENID_V2), + ("extoidc", "External OpenID %s" % AUTH_OPENID_CONNECT), + ("extcert", "External Certificate %s" % AUTH_CERTIFICATE), ] - for (auth_type, expected) in test_cases: + for auth_type, expected in test_cases: result = auth_type_description(self.configuration, auth_type) self.assertEqual(result, expected) def test_auth_type_description_unknown(self): """Test auth_type_description returns 'UNKNOWN' for invalid types""" - self.assertEqual(auth_type_description(self.configuration, 'invalid'), - 'UNKNOWN') - self.assertEqual(auth_type_description( - self.configuration, ''), 'UNKNOWN') - self.assertEqual(auth_type_description( - self.configuration, None), 'UNKNOWN') + self.assertEqual( + auth_type_description(self.configuration, "invalid"), "UNKNOWN" + ) + self.assertEqual( + auth_type_description(self.configuration, ""), "UNKNOWN" + ) + self.assertEqual( + auth_type_description(self.configuration, None), "UNKNOWN" + ) def test_auth_type_description_empty_titles(self): """Test auth_type_description handles empty titles in configuration""" - from mig.shared.defaults import AUTH_CERTIFICATE, AUTH_OPENID_CONNECT, \ - AUTH_OPENID_V2 + from mig.shared.defaults import ( + AUTH_CERTIFICATE, + AUTH_OPENID_CONNECT, + AUTH_OPENID_V2, + ) self.configuration.user_mig_oid_title = "" self.configuration.user_mig_cert_title = "" @@ -705,15 +796,15 @@ def test_auth_type_description_empty_titles(self): self.configuration.user_ext_cert_title = "" test_cases = [ - ('migoid', ' %s' % AUTH_OPENID_V2), - ('migoidc', ' %s' % AUTH_OPENID_CONNECT), - ('migcert', ' %s' % AUTH_CERTIFICATE), - ('extoid', ' %s' % AUTH_OPENID_V2), - ('extoidc', ' %s' % AUTH_OPENID_CONNECT), - ('extcert', ' %s' % AUTH_CERTIFICATE), + ("migoid", " %s" % AUTH_OPENID_V2), + ("migoidc", " %s" % AUTH_OPENID_CONNECT), + ("migcert", " %s" % AUTH_CERTIFICATE), + ("extoid", " %s" % AUTH_OPENID_V2), + ("extoidc", " %s" % AUTH_OPENID_CONNECT), + ("extcert", " %s" % AUTH_CERTIFICATE), ] - for (auth_type, expected) in test_cases: + for auth_type, expected in test_cases: result = auth_type_description(self.configuration, auth_type) self.assertEqual(result, expected) @@ -721,11 +812,16 @@ def test_allow_script_gdp_enabled_anonymous_allowed(self): """Test allow_script with GDP enabled, anonymous user, and script allowed.""" self.configuration.site_enable_gdp = True - script_name = valid_gdp_anon_scripts[0] if valid_gdp_anon_scripts \ - else 'allowed_script.py' # Use a valid script or a default + script_name = ( + valid_gdp_anon_scripts[0] + if valid_gdp_anon_scripts + 
else "allowed_script.py" + ) # Use a valid script or a default if not valid_gdp_anon_scripts: - print("WARNING: valid_gdp_anon_scripts is empty. Using " - "'allowed_script.py' which may cause a test failure.") + print( + "WARNING: valid_gdp_anon_scripts is empty. Using " + "'allowed_script.py' which may cause a test failure." + ) allow, msg = allow_script(self.configuration, script_name, None) self.assertTrue(allow) self.assertEqual(msg, "") @@ -734,29 +830,41 @@ def test_allow_script_gdp_enabled_anonymous_disallowed(self): """Test allow_script with GDP enabled, anonymous user, and script disallowed.""" self.configuration.site_enable_gdp = True - script_name = 'disallowed_script.py' + script_name = "disallowed_script.py" # Ensure the script is not in valid_gdp_anon_scripts if script_name in valid_gdp_anon_scripts: valid_gdp_anon_scripts.remove(script_name) allow, msg = allow_script(self.configuration, script_name, None) self.assertFalse(allow) - self.assertEqual(msg, "anonoymous access to functionality disabled " - "by site configuration!") + self.assertEqual( + msg, + "anonoymous access to functionality disabled " + "by site configuration!", + ) def test_allow_script_gdp_enabled_authenticated_allowed(self): """Test allow_script with GDP enabled, authenticated user, and script allowed.""" self.configuration.site_enable_gdp = True - script_name = valid_gdp_auth_scripts[0] if valid_gdp_auth_scripts \ - else valid_gdp_anon_scripts[0] if valid_gdp_anon_scripts \ - else 'allowed_script.py' + script_name = ( + valid_gdp_auth_scripts[0] + if valid_gdp_auth_scripts + else ( + valid_gdp_anon_scripts[0] + if valid_gdp_anon_scripts + else "allowed_script.py" + ) + ) if not valid_gdp_auth_scripts and not valid_gdp_anon_scripts: - print("WARNING: valid_gdp_auth_scripts and " - "valid_gdp_anon_scripts are empty. Using " - "'allowed_script.py' which may cause a test failure.") + print( + "WARNING: valid_gdp_auth_scripts and " + "valid_gdp_anon_scripts are empty. Using " + "'allowed_script.py' which may cause a test failure." 
+ ) allow, msg = allow_script( - self.configuration, script_name, 'test_client') + self.configuration, script_name, "test_client" + ) self.assertTrue(allow) self.assertEqual(msg, "") @@ -764,7 +872,7 @@ def test_allow_script_gdp_enabled_authenticated_disallowed(self): """Test allow_script with GDP enabled, authenticated user, and script disallowed.""" self.configuration.site_enable_gdp = True - script_name = 'disallowed_script.py' + script_name = "disallowed_script.py" # Ensure the script is not in valid_gdp_auth_scripts or # valid_gdp_anon_scripts @@ -774,98 +882,108 @@ def test_allow_script_gdp_enabled_authenticated_disallowed(self): valid_gdp_anon_scripts.remove(script_name) allow, msg = allow_script( - self.configuration, script_name, 'test_client') + self.configuration, script_name, "test_client" + ) self.assertFalse(allow) - self.assertEqual(msg, "all access to functionality disabled by site " - "configuration!") + self.assertEqual( + msg, + "all access to functionality disabled by site " "configuration!", + ) def test_allow_script_gdp_disabled(self): """Test allow_script with GDP disabled.""" self.configuration.site_enable_gdp = False - allow, msg = allow_script(self.configuration, 'any_script.py', - 'test_client') + allow, msg = allow_script( + self.configuration, "any_script.py", "test_client" + ) self.assertTrue(allow) self.assertEqual(msg, "") def test_allow_script_gdp_disabled_anonymous(self): """Test allow_script with GDP disabled and anonymous user.""" self.configuration.site_enable_gdp = False - allow, msg = allow_script(self.configuration, 'any_script.py', None) + allow, msg = allow_script(self.configuration, "any_script.py", None) self.assertTrue(allow) self.assertEqual(msg, "") def test_requested_page_normal(self): """Test requested_page with basic environment""" fake_env = { - 'SCRIPT_NAME': '/cgi-bin/home.py', - 'REQUEST_URI': '/cgi-bin/home.py' + "SCRIPT_NAME": "/cgi-bin/home.py", + "REQUEST_URI": "/cgi-bin/home.py", } - self.assertEqual(requested_page(fake_env), '/cgi-bin/home.py') + self.assertEqual(requested_page(fake_env), "/cgi-bin/home.py") def test_requested_page_name_only(self): """Test requested_page with name_only argument""" fake_env = { - 'BACKEND_NAME': 'search.py', - 'PATH_INFO': '/cgi-bin/search.py/path' + "BACKEND_NAME": "search.py", + "PATH_INFO": "/cgi-bin/search.py/path", } - result = requested_page(fake_env, name_only=True, fallback='fallback') - self.assertEqual(result, 'search.py') + result = requested_page(fake_env, name_only=True, fallback="fallback") + self.assertEqual(result, "search.py") def test_requested_page_strip_extension(self): """Test requested_page with strip_ext argument""" - fake_env = {'REQUEST_URI': '/cgi-bin/file.py?query=val'} + fake_env = {"REQUEST_URI": "/cgi-bin/file.py?query=val"} result = requested_page(fake_env, strip_ext=True) - self.assertEqual(result, '/cgi-bin/file') + self.assertEqual(result, "/cgi-bin/file") def test_requested_page_priority(self): """Test environment variable priority order""" _init_env = { - 'BACKEND_NAME': 'backend.py', - 'SCRIPT_URL': '/cgi-bin/script_url.py', - 'SCRIPT_URI': 'https://host/cgi-bin/script_uri.py', - 'PATH_INFO': '/cgi-bin/path_info.py', - 'REQUEST_URI': '/cgi-bin/req_uri.py' + "BACKEND_NAME": "backend.py", + "SCRIPT_URL": "/cgi-bin/script_url.py", + "SCRIPT_URI": "https://host/cgi-bin/script_uri.py", + "PATH_INFO": "/cgi-bin/path_info.py", + "REQUEST_URI": "/cgi-bin/req_uri.py", } - priority_order = ['BACKEND_NAME', 'SCRIPT_URL', 'SCRIPT_URI', - 'PATH_INFO', 'REQUEST_URI'] + 
priority_order = [ + "BACKEND_NAME", + "SCRIPT_URL", + "SCRIPT_URI", + "PATH_INFO", + "REQUEST_URI", + ] for var in priority_order: # Reset fake_env each time fake_env = dict([pair for pair in _init_env.items()]) current_env = {var: fake_env[var]} - if var != 'SCRIPT_URI': + if var != "SCRIPT_URI": expected = fake_env[var] else: - expected = 'https://host/cgi-bin/script_uri.py' + expected = "https://host/cgi-bin/script_uri.py" result = requested_page(current_env) - self.assertEqual(result, expected, - "failed priority for %s" % var) + self.assertEqual(result, expected, "failed priority for %s" % var) # Remove higher priority variables one by one - for higher_var in priority_order[:priority_order.index(var)]: + for higher_var in priority_order[: priority_order.index(var)]: del fake_env[higher_var] result = requested_page(fake_env) - self.assertEqual(result, fake_env[var], - "failed fallthrough to %s" % var) + self.assertEqual( + result, fake_env[var], "failed fallthrough to %s" % var + ) def test_requested_page_unsafe_filter(self): """Test unsafe character filtering""" test_cases = [ - ('/cgi-bin/unsafe.py' - fake_env = {'REQUEST_URI': dangerous} + dangerous = "/cgi-bin/unsafe.py" + fake_env = {"REQUEST_URI": dangerous} unsafe_result = requested_page(fake_env, include_unsafe=True) self.assertEqual(unsafe_result, dangerous) @@ -874,39 +992,39 @@ def test_requested_page_include_unsafe(self): def test_requested_page_query_stripping(self): """Test removal of query parameters""" - test_input = '/cgi-bin/script.py?query=value&param=data' - fake_env = {'REQUEST_URI': test_input} + test_input = "/cgi-bin/script.py?query=value&param=data" + fake_env = {"REQUEST_URI": test_input} result = requested_page(fake_env) - self.assertEqual(result, '/cgi-bin/script.py') + self.assertEqual(result, "/cgi-bin/script.py") def test_requested_page_fallback(self): """Test fallback to default""" fake_env = {} - fallback = 'special.py' + fallback = "special.py" result = requested_page(fake_env, fallback=fallback) self.assertEqual(result, fallback) def test_requested_page_fallback_despite_os_environ_value(self): """Test fallback to default""" fake_env = {} - fallback = 'special.py' - os.environ['BACKEND_NAME'] = 'BOGUS' + fallback = "special.py" + os.environ["BACKEND_NAME"] = "BOGUS" result = requested_page(fake_env, fallback=fallback) - del os.environ['BACKEND_NAME'] + del os.environ["BACKEND_NAME"] self.assertEqual(result, fallback) def test_requested_url_base_normal(self): """Test requested_url_base with basic complete URL""" - fake_env = {'SCRIPT_URI': 'https://example.com/path/to/script.py'} + fake_env = {"SCRIPT_URI": "https://example.com/path/to/script.py"} result = requested_url_base(fake_env) - expected = 'https://example.com' + expected = "https://example.com" self.assertEqual(result, expected) def test_requested_url_base_custom_field(self): """Test requested_url_base with custom uri_field parameter""" - fake_env = {'CUSTOM_FIELD_URI': 'http://server.org:8001/base/'} - result = requested_url_base(fake_env, uri_field='CUSTOM_FIELD_URI') - expected = 'http://server.org:8001' + fake_env = {"CUSTOM_FIELD_URI": "http://server.org:8001/base/"} + result = requested_url_base(fake_env, uri_field="CUSTOM_FIELD_URI") + expected = "http://server.org:8001" self.assertEqual(result, expected) # TODO: adjust tested function to bail out on missing uri_field @@ -915,56 +1033,55 @@ def test_requested_url_base_missing(self): """Test requested_url_base when uri_field not present""" fake_env = {} result = requested_url_base(fake_env) - 
self.assertEqual(result, '') + self.assertEqual(result, "") def test_requested_url_base_safe_filter(self): """Test unsafe character filtering in url base""" - test_url = 'https://user:pass@evil.com/' - fake_env = {'SCRIPT_URI': test_url} + test_url = "https://user:pass@evil.com/" + fake_env = {"SCRIPT_URI": test_url} safe_result = requested_url_base(fake_env) - expected_safe = 'https://user:passevil.com' + expected_safe = "https://user:passevil.com" self.assertEqual(safe_result, expected_safe) def test_requested_url_base_include_unsafe(self): """Test include_unsafe argument behavior""" - test_url = 'http://[::1]<markup>?' - fake_env = {'SCRIPT_URI': test_url} + test_url = "http://[::1]<markup>?" + fake_env = {"SCRIPT_URI": test_url} unsafe_result = requested_url_base(fake_env, include_unsafe=True) - self.assertEqual(unsafe_result, 'http://[::1]<markup>?') + self.assertEqual(unsafe_result, "http://[::1]<markup>?") safe_result = requested_url_base(fake_env, include_unsafe=False) - self.assertEqual(safe_result, 'http://::1markup') + self.assertEqual(safe_result, "http://::1markup") safe_result = requested_url_base(fake_env) - self.assertEqual(safe_result, 'http://::1markup') + self.assertEqual(safe_result, "http://::1markup") def test_requested_url_base_split_valid_edge_cases(self): """Test URL base splitting on valid edge cases""" test_cases = [ - ('https://site.com', 'https://site.com'), - ('http://a/single/slash', 'http://a'), - ('file:///absolute/path', 'file://'), - ('invalid.proto://double/slash', 'invalid.proto://double') + ("https://site.com", "https://site.com"), + ("http://a/single/slash", "http://a"), + ("file:///absolute/path", "file://"), + ("invalid.proto://double/slash", "invalid.proto://double"), ] - for (input_url, expected) in test_cases: - fake_env = {'SCRIPT_URI': input_url} + for input_url, expected in test_cases: + fake_env = {"SCRIPT_URI": input_url} result = requested_url_base(fake_env) - self.assertEqual(result, expected, - "failed for %s" % input_url) + self.assertEqual(result, expected, "failed for %s" % input_url) # TODO: adjust function to bail out on invalid URLs and enable next @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_split_invalid_edge_cases(self): """Test URL base splitting on invalid edge cases""" test_cases = [ - ('', ''), - ('/', '/'), - ('/single', '/single'), - ('/double/slash', '/double/slash'), - ('invalid.proto:/1st/2nd/slash', 'invalid.proto:/1st/2nd/slash'), - ('invalid.proto://double/slash', 'invalid.proto://double') + ("", ""), + ("/", "/"), + ("/single", "/single"), + ("/double/slash", "/double/slash"), + ("invalid.proto:/1st/2nd/slash", "invalid.proto:/1st/2nd/slash"), + ("invalid.proto://double/slash", "invalid.proto://double"), ] - for (input_url, expected) in test_cases: - fake_env = {'SCRIPT_URI': input_url} + for input_url, expected in test_cases: + fake_env = {"SCRIPT_URI": input_url} try: result = requested_url_base(fake_env) except ValueError: @@ -975,7 +1092,7 @@ def test_requested_url_base_split_invalid_edge_cases(self): @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_relative_path(self): """Test relative paths in URL""" - fake_env = {'SCRIPT_URI': '/cgi-bin/script.py'} + fake_env = {"SCRIPT_URI": "/cgi-bin/script.py"} try: result = requested_url_base(fake_env) except ValueError: @@ -986,10 +1103,10 @@ def test_requested_url_base_relative_path(self): @unittest.skipIf(True, "requires fix in tested function") def test_requested_url_base_special_chars(self): """Test handling of special 
characters in URL""" - test_url = 'http://üñîçøðê.net/path' - fake_env = {'SCRIPT_URI': test_url} + test_url = "http://üñîçøðê.net/path" + fake_env = {"SCRIPT_URI": test_url} result = requested_url_base(fake_env) - self.assertEqual(result, 'http://üñîçøðê.net') + self.assertEqual(result, "http://üñîçøðê.net") def test_verify_local_url_direct_match(self): """Test verify_local_url with direct match to known site URL""" @@ -1005,7 +1122,9 @@ def test_verify_local_url_subpath_match(self): def test_verify_local_url_public_alias(self): """Test verify_local_url with public alias domain""" - self.configuration.migserver_public_alias_url = "https://grid.example.org" + self.configuration.migserver_public_alias_url = ( + "https://grid.example.org" + ) test_url = "https://grid.example.org/cgi-bin/file.py" self.assertTrue(verify_local_url(self.configuration, test_url)) @@ -1017,103 +1136,122 @@ def test_verify_local_url_absolute_path(self): def test_verify_local_url_relative_path(self): """Test verify_local_url with relative path""" test_url = "subdir/script.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_external_domain(self): """Test verify_local_url rejects external domains""" test_url = "https://evil.com/malicious.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_invalid_url(self): """Test verify_local_url rejects invalid/malformed URLs""" test_url = "javascript:alert('xss')" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_missing_https(self): """Test verify_local_url with HTTP when only HTTPS supported""" test_url = "http://mig.cert/cgi-bin/home.py" self.configuration.migserver_https_mig_cert_url = "https://mig.cert" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_verify_local_url_different_port(self): """Test verify_local_url rejects same domain with different port""" self.configuration.migserver_https_ext_cert_url = "https://ext.cert:443" test_url = "https://ext.cert:444/cgi-bin/file.py" - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: status = verify_local_url(self.configuration, test_url) self.assertFalse(status) - 
self.assertTrue(any('request verification failed' in msg for msg in - log_capture.output)) + self.assertTrue( + any( + "request verification failed" in msg + for msg in log_capture.output + ) + ) def test_invisible_path_file(self): """Test invisible_path detects names in invisible files""" - invisible_filename = '.htaccess' - visible_filename = 'visible.txt' + invisible_filename = ".htaccess" + visible_filename = "visible.txt" # Test with invisible filename - self.assertTrue(invisible_path('/some/path/%s' % invisible_filename)) + self.assertTrue(invisible_path("/some/path/%s" % invisible_filename)) self.assertTrue(invisible_path(invisible_filename)) self.assertTrue(invisible_path(invisible_filename, True)) # Test with visible filename self.assertFalse(invisible_path(visible_filename)) - self.assertFalse(invisible_path('/some/path/%s' % visible_filename)) + self.assertFalse(invisible_path("/some/path/%s" % visible_filename)) def test_invisible_path_dir(self): """Test invisible_path detects paths in invisible dir""" - invisible_dirname = '.vgridscm' - visible_dirname = 'somedir' + invisible_dirname = ".vgridscm" + visible_dirname = "somedir" # Test different forms of invisible directory path self.assertTrue(invisible_path(invisible_dirname)) - self.assertTrue(invisible_path('/%s' % invisible_dirname)) - self.assertTrue(invisible_path('/parent/%s' % invisible_dirname)) - self.assertTrue(invisible_path('%s/sub' % invisible_dirname)) - self.assertTrue(invisible_path('/%s/file' % invisible_dirname)) + self.assertTrue(invisible_path("/%s" % invisible_dirname)) + self.assertTrue(invisible_path("/parent/%s" % invisible_dirname)) + self.assertTrue(invisible_path("%s/sub" % invisible_dirname)) + self.assertTrue(invisible_path("/%s/file" % invisible_dirname)) # Test visible directory self.assertFalse(invisible_path(visible_dirname)) - self.assertFalse(invisible_path('/%s' % visible_dirname)) - self.assertFalse(invisible_path('/parent/%s' % visible_dirname)) + self.assertFalse(invisible_path("/%s" % visible_dirname)) + self.assertFalse(invisible_path("/parent/%s" % visible_dirname)) def test_invisible_path_vgrid_exception(self): """Test allow_vgrid_scripts excludes valid vgrid xgi scripts""" - invisible_dirname = '.vgridscm' - vgrid_script = '.vgridscm/cgi-bin/hgweb.cgi' + invisible_dirname = ".vgridscm" + vgrid_script = ".vgridscm/cgi-bin/hgweb.cgi" test_patterns = [ - '/%s/%s' % (invisible_dirname, vgrid_script), - '/root/%s/sub/%s' % (invisible_dirname, vgrid_script), - '/%s/prefix%ssuffix' % (invisible_dirname, vgrid_script), - '/%s/similar_script.py' % invisible_dirname, - '/path/to/%s' % vgrid_script - + "/%s/%s" % (invisible_dirname, vgrid_script), + "/root/%s/sub/%s" % (invisible_dirname, vgrid_script), + "/%s/prefix%ssuffix" % (invisible_dirname, vgrid_script), + "/%s/similar_script.py" % invisible_dirname, + "/path/to/%s" % vgrid_script, ] test_expects = [False, False, False, True, False] test_iter = zip(test_patterns, test_expects) - for (i, (path, expected)) in enumerate(test_iter): + for i, (path, expected) in enumerate(test_iter): self.assertEqual( invisible_path(path, allow_vgrid_scripts=True), expected, "test case %d: path %r should %sbe invisible with scripts" - % (i, path, "" if expected else "not ") + % (i, path, "" if expected else "not "), ) # Should still be invisible without exception flag @@ -1122,23 +1260,24 @@ def test_invisible_path_vgrid_exception(self): invisible_path(path, allow_vgrid_scripts=False), expect_no_exception, "test case %d: path %r should %sbe invisible without 
scripts" - % (i, path, "" if expect_no_exception else "not ") + % (i, path, "" if expect_no_exception else "not "), ) def test_invisible_path_edge_cases(self): """Test invisible_path handles edge cases""" from mig.shared.defaults import _user_invisible_dirs + invisible_dirname = _user_invisible_dirs[0] # Empty path - self.assertFalse(invisible_path('')) - self.assertFalse(invisible_path('', allow_vgrid_scripts=True)) + self.assertFalse(invisible_path("")) + self.assertFalse(invisible_path("", allow_vgrid_scripts=True)) # Root path - self.assertFalse(invisible_path('/')) + self.assertFalse(invisible_path("/")) # Path that only contains invisible directory substring - substring_path = '/prefix%ssuffix/file' % invisible_dirname + substring_path = "/prefix%ssuffix/file" % invisible_dirname self.assertFalse(invisible_path(substring_path)) def test_client_alias(self): @@ -1176,21 +1315,21 @@ def test_get_short_id_with_gdp(self): def test_get_user_id_x509_format(self): """Test get_user_id returns DN for X509 format""" self.configuration.site_user_id_format = "X509" - user = {'distinguished_name': TEST_USER_ID} + user = {"distinguished_name": TEST_USER_ID} result = get_user_id(self.configuration, user) self.assertEqual(result, TEST_USER_ID) def test_get_user_id_uuid_format(self): """Test get_user_id returns UUID when configured""" self.configuration.site_user_id_format = "UUID" - user = {'unique_id': "123e4567-e89b-12d3-a456-426614174000"} + user = {"unique_id": "123e4567-e89b-12d3-a456-426614174000"} result = get_user_id(self.configuration, user) self.assertEqual(result, "123e4567-e89b-12d3-a456-426614174000") def test_get_client_id(self): """Test get_client_id extracts DN from user dict""" test_dn = "/C=US/CN=Alice" - user = {'distinguished_name': test_dn, 'other': 'field'} + user = {"distinguished_name": test_dn, "other": "field"} result = get_client_id(user) self.assertEqual(result, test_dn) @@ -1204,14 +1343,8 @@ def test_hexlify_unhexlify_roundtrip(self): def test_is_gdp_user_detection(self): """Test is_gdp_user detects GDP project presence""" - self.assertTrue(is_gdp_user( - self.configuration, - "/GDP_PROJ=12345" - )) - self.assertFalse(is_gdp_user( - self.configuration, - "/CN=Regular User" - )) + self.assertTrue(is_gdp_user(self.configuration, "/GDP_PROJ=12345")) + self.assertFalse(is_gdp_user(self.configuration, "/CN=Regular User")) def test_sandbox_resource_identification(self): """Test sandbox_resource identifies sandboxes""" @@ -1235,8 +1368,8 @@ def test_invisible_dir_detection(self): def test_requested_backend_extraction(self): """Test requested_backend extracts backend name from environ""" test_env = { - 'BACKEND_NAME': '/cgi-bin/fileman.py', - 'PATH_TRANSLATED': '/wsgi-bin/fileman.py' + "BACKEND_NAME": "/cgi-bin/fileman.py", + "PATH_TRANSLATED": "/wsgi-bin/fileman.py", } result = requested_backend(test_env) self.assertEqual(result, "fileman") @@ -1258,15 +1391,16 @@ def test_get_xgi_bin_wsgi_vs_cgi(self): """Test get_xgi_bin returns correct script bin based on config""" # Test WSGI enabled self.configuration.site_enable_wsgi = True - self.assertEqual(get_xgi_bin(self.configuration), 'wsgi-bin') + self.assertEqual(get_xgi_bin(self.configuration), "wsgi-bin") # Test WSGI disabled self.configuration.site_enable_wsgi = False - self.assertEqual(get_xgi_bin(self.configuration), 'cgi-bin') + self.assertEqual(get_xgi_bin(self.configuration), "cgi-bin") # Test legacy force - self.assertEqual(get_xgi_bin(self.configuration, force_legacy=True), - 'cgi-bin') + self.assertEqual( + 
get_xgi_bin(self.configuration, force_legacy=True), "cgi-bin" + ) def test_valid_dir_input(self): """Test valid_dir_input prevents path traversal attempts""" @@ -1277,11 +1411,11 @@ def test_valid_dir_input(self): ("../illegal", False), ("/absolute", False), ] - for (relative_path, expected) in test_cases: + for relative_path, expected in test_cases: self.assertEqual( valid_dir_input(base, relative_path), expected, - "failed for %s" % relative_path + "failed for %s" % relative_path, ) def test_user_base_dir(self): @@ -1306,32 +1440,56 @@ def test_brief_list(self): # List longer than max_entries gets shortened long_list = list(range(15)) - expected_long = [0, 1, 2, 3, 4, ' ... shortened ... ', 10, 11, 12, - 13, 14] + expected_long = [ + 0, + 1, + 2, + 3, + 4, + " ... shortened ... ", + 10, + 11, + 12, + 13, + 14, + ] self.assertEqual(brief_list(long_list), expected_long) # Custom max_entries with odd number custom_odd_list = list(range(10)) - expected_odd = [0, 1, 2, 3, ' ... shortened ... ', 6, 7, 8, 9] + expected_odd = [0, 1, 2, 3, " ... shortened ... ", 6, 7, 8, 9] self.assertEqual(brief_list(custom_odd_list, 9), expected_odd) # Range objects should be handled properly input_range = range(20) - expected_range = [0, 1, 2, 3, 4, ' ... shortened ... ', 15, 16, 17, - 18, 19] + expected_range = [ + 0, + 1, + 2, + 3, + 4, + " ... shortened ... ", + 15, + 16, + 17, + 18, + 19, + ] self.assertEqual(brief_list(input_range), expected_range) # Edge case - max_entries=2 - self.assertEqual(brief_list([1, 2, 3, 4], 2), - [1, ' ... shortened ... ', 4]) + self.assertEqual( + brief_list([1, 2, 3, 4], 2), [1, " ... shortened ... ", 4] + ) # Very small max_entries - self.assertEqual(brief_list([1, 2, 3, 4], 3), - [1, ' ... shortened ... ', 4]) + self.assertEqual( + brief_list([1, 2, 3, 4], 3), [1, " ... shortened ... ", 4] + ) # Non-integer input - str_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - expected_str = ['a', 'b', 'c', ' ... shortened ... ', 'e', 'f', 'g'] + str_list = ["a", "b", "c", "d", "e", "f", "g"] + expected_str = ["a", "b", "c", " ... shortened ... ", "e", "f", "g"] self.assertEqual(brief_list(str_list, 7), str_list) # At max_entries # TODO: fix tested function to handle these and enable test @@ -1339,19 +1497,20 @@ def test_brief_list(self): def test_brief_list_edge_cases(self): """Test brief_list helper function for compact list on edge cases""" # Edge case - max_entries=1 - self.assertEqual(brief_list([1, 2, 3], 1), [' ... shortened ... ']) + self.assertEqual(brief_list([1, 2, 3], 1), [" ... shortened ... "]) # Edge case - even short number of max_entries - str_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - self.assertEqual(brief_list(str_list, 6), - ['a', 'b', ' ... shortened ... ', 'f', 'g']) + str_list = ["a", "b", "c", "d", "e", "f", "g"] + self.assertEqual( + brief_list(str_list, 6), ["a", "b", " ... shortened ... ", "f", "g"] + ) class TestMigSharedBase__legacy_main(MigTestCase): """Run mig.shared.base legacy self-test""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" # TODO: migrate all legacy self-check functionality into the above? 
     def test_existing_main(self):
@@ -1362,13 +1521,19 @@ def raise_on_error_exit(exit_code):
             if raise_on_error_exit.last_print is not None:
                 identifying_message = raise_on_error_exit.last_print
             else:
-                identifying_message = 'unknown'
+                identifying_message = "unknown"
             raise AssertionError(
-                'legacy test failure: %s' % (identifying_message,))
+                "legacy test failure: %s" % (identifying_message,)
+            )
+
         raise_on_error_exit.last_print = None
 
         def record_last_print(value):
             """Keep track of printed output"""
             raise_on_error_exit.last_print = value
 
-        legacy_main(self.configuration, print=record_last_print, _exit=raise_on_error_exit)
+        legacy_main(
+            self.configuration,
+            print=record_last_print,
+            _exit=raise_on_error_exit,
+        )
diff --git a/tests/test_mig_shared_cloud.py b/tests/test_mig_shared_cloud.py
index 241a79970..00e473b99 100644
--- a/tests/test_mig_shared_cloud.py
+++ b/tests/test_mig_shared_cloud.py
@@ -34,63 +34,74 @@
 
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__))))
 
-from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \
-    cleanpath, testmain
-from mig.shared.cloud import cloud_load_instance, cloud_save_instance, \
-    allowed_cloud_images
-
-DUMMY_USER = 'dummy-user'
-DUMMY_SETTINGS_DIR = 'dummy_user_settings'
+from mig.shared.cloud import (
+    allowed_cloud_images,
+    cloud_load_instance,
+    cloud_save_instance,
+)
+from tests.support import (
+    TEST_OUTPUT_DIR,
+    FakeConfiguration,
+    MigTestCase,
+    cleanpath,
+    testmain,
+)
+
+DUMMY_USER = "dummy-user"
+DUMMY_SETTINGS_DIR = "dummy_user_settings"
 DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR)
 DUMMY_CLOUD = "CLOUD"
-DUMMY_FLAVOR = 'openstack'
-DUMMY_LABEL = 'dummy-label'
-DUMMY_IMAGE = 'dummy-image'
-DUMMY_HEX_ID = 'deadbeef-dead-beef-dead-beefdeadbeef'
-
-DUMMY_CLOUD_SPEC = {'service_title': 'CLOUDTITLE', 'service_name': 'CLOUDNAME',
-                    'service_desc': 'A Cloud for migrid site',
-                    'service_provider_flavor': 'openstack',
-                    'service_hosts': 'https://myopenstack-cloud.org:5000/v3',
-                    'service_rules_of_conduct': 'rules-of-conduct.pdf',
-                    'service_max_user_instances': '0',
-                    'service_max_user_instances_map': {DUMMY_USER: '1'},
-                    'service_allowed_images': DUMMY_IMAGE,
-                    'service_allowed_images_map': {DUMMY_USER: 'ALL'},
-                    'service_user_map': {DUMMY_IMAGE, 'user'},
-                    'service_image_alias_map': {DUMMY_IMAGE.lower():
-                                                DUMMY_IMAGE},
-                    'service_flavor_id': DUMMY_HEX_ID,
-                    'service_flavor_id_map': {DUMMY_USER: DUMMY_HEX_ID},
-                    'service_network_id': DUMMY_HEX_ID,
-                    'service_key_id_map': {},
-                    'service_sec_group_id': DUMMY_HEX_ID,
-                    'service_floating_network_id': DUMMY_HEX_ID,
-                    'service_availability_zone': 'myopenstack',
-                    'service_jumphost_address': 'jumphost.somewhere.org',
-                    'service_jumphost_user': 'cloud',
-                    'service_jumphost_manage_keys_script':
-                    'cloud_manage_keys.py',
-                    'service_jumphost_manage_keys_coding': 'base16',
-                    'service_network_id_map': {},
-                    'service_sec_group_id_map': {},
-                    'service_floating_network_id_map': {},
-                    'service_availability_zone_map': {},
-                    'service_jumphost_address_map': {},
-                    'service_jumphost_user_map': {}}
-DUMMY_CONF = FakeConfiguration(user_settings=DUMMY_SETTINGS_PATH,
-                               site_cloud_access=[('distinguished_name', '.*')],
-                               cloud_services=[DUMMY_CLOUD_SPEC])
-
-DUMMY_INSTANCE_ID = '%s:%s:%s' % (DUMMY_USER, DUMMY_LABEL, DUMMY_HEX_ID)
+DUMMY_FLAVOR = "openstack"
+DUMMY_LABEL = "dummy-label"
+DUMMY_IMAGE = "dummy-image"
+DUMMY_HEX_ID = "deadbeef-dead-beef-dead-beefdeadbeef"
+
+DUMMY_CLOUD_SPEC = {
+    "service_title": "CLOUDTITLE",
+    "service_name": "CLOUDNAME",
+    "service_desc": "A Cloud for migrid site",
+    "service_provider_flavor": "openstack",
+    "service_hosts": "https://myopenstack-cloud.org:5000/v3",
+    "service_rules_of_conduct": "rules-of-conduct.pdf",
+    "service_max_user_instances": "0",
+    "service_max_user_instances_map": {DUMMY_USER: "1"},
+    "service_allowed_images": DUMMY_IMAGE,
+    "service_allowed_images_map": {DUMMY_USER: "ALL"},
+    "service_user_map": {DUMMY_IMAGE, "user"},
+    "service_image_alias_map": {DUMMY_IMAGE.lower(): DUMMY_IMAGE},
+    "service_flavor_id": DUMMY_HEX_ID,
+    "service_flavor_id_map": {DUMMY_USER: DUMMY_HEX_ID},
+    "service_network_id": DUMMY_HEX_ID,
+    "service_key_id_map": {},
+    "service_sec_group_id": DUMMY_HEX_ID,
+    "service_floating_network_id": DUMMY_HEX_ID,
+    "service_availability_zone": "myopenstack",
+    "service_jumphost_address": "jumphost.somewhere.org",
+    "service_jumphost_user": "cloud",
+    "service_jumphost_manage_keys_script": "cloud_manage_keys.py",
+    "service_jumphost_manage_keys_coding": "base16",
+    "service_network_id_map": {},
+    "service_sec_group_id_map": {},
+    "service_floating_network_id_map": {},
+    "service_availability_zone_map": {},
+    "service_jumphost_address_map": {},
+    "service_jumphost_user_map": {},
+}
+DUMMY_CONF = FakeConfiguration(
+    user_settings=DUMMY_SETTINGS_PATH,
+    site_cloud_access=[("distinguished_name", ".*")],
+    cloud_services=[DUMMY_CLOUD_SPEC],
+)
+
+DUMMY_INSTANCE_ID = "%s:%s:%s" % (DUMMY_USER, DUMMY_LABEL, DUMMY_HEX_ID)
 DUMMY_INSTANCE_DICT = {
     DUMMY_INSTANCE_ID: {
-        'INSTANCE_LABEL': DUMMY_LABEL,
-        'INSTANCE_IMAGE': DUMMY_IMAGE,
-        'INSTANCE_ID': DUMMY_INSTANCE_ID,
-        'IMAGE_ID': DUMMY_IMAGE,
-        'CREATED_TIMESTAMP': "%d" % time.time(),
-        'USER_CERT': DUMMY_USER
+        "INSTANCE_LABEL": DUMMY_LABEL,
+        "INSTANCE_IMAGE": DUMMY_IMAGE,
+        "INSTANCE_ID": DUMMY_INSTANCE_ID,
+        "IMAGE_ID": DUMMY_IMAGE,
+        "CREATED_TIMESTAMP": "%d" % time.time(),
+        "USER_CERT": DUMMY_USER,
     }
 }
@@ -102,38 +113,46 @@ def test_cloud_save_load(self):
         os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER))
         cleanpath(DUMMY_SETTINGS_DIR, self)
 
-        save_status = cloud_save_instance(DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD,
-                                          DUMMY_LABEL, DUMMY_INSTANCE_DICT)
+        save_status = cloud_save_instance(
+            DUMMY_CONF,
+            DUMMY_USER,
+            DUMMY_CLOUD,
+            DUMMY_LABEL,
+            DUMMY_INSTANCE_DICT,
+        )
         self.assertTrue(save_status)
 
-        saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER,
-                                  '%s.state' % DUMMY_CLOUD)
+        saved_path = os.path.join(
+            DUMMY_SETTINGS_PATH, DUMMY_USER, "%s.state" % DUMMY_CLOUD
+        )
         self.assertTrue(os.path.exists(saved_path))
 
-        instance = cloud_load_instance(DUMMY_CONF, DUMMY_USER,
-                                       DUMMY_CLOUD, DUMMY_LABEL)
+        instance = cloud_load_instance(
+            DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, DUMMY_LABEL
+        )
         # NOTE: instance should be a non-empty dict at this point
         self.assertTrue(isinstance(instance, dict))
         # print(instance)
         self.assertTrue(DUMMY_INSTANCE_ID in instance)
         instance_dict = instance[DUMMY_INSTANCE_ID]
-        self.assertEqual(instance_dict['INSTANCE_LABEL'], DUMMY_LABEL)
-        self.assertEqual(instance_dict['INSTANCE_IMAGE'], DUMMY_IMAGE)
-        self.assertEqual(instance_dict['INSTANCE_ID'], DUMMY_INSTANCE_ID)
-        self.assertEqual(instance_dict['IMAGE_ID'], DUMMY_IMAGE)
-        self.assertEqual(instance_dict['USER_CERT'], DUMMY_USER)
+        self.assertEqual(instance_dict["INSTANCE_LABEL"], DUMMY_LABEL)
+        self.assertEqual(instance_dict["INSTANCE_IMAGE"], DUMMY_IMAGE)
+        self.assertEqual(instance_dict["INSTANCE_ID"], DUMMY_INSTANCE_ID)
+        self.assertEqual(instance_dict["IMAGE_ID"], DUMMY_IMAGE)
+        self.assertEqual(instance_dict["USER_CERT"], DUMMY_USER)
 
-    @unittest.skip('Work in progress - currently requires remote openstack')
+    @unittest.skip("Work in progress - currently requires remote openstack")
     def test_cloud_allowed_images(self):
         os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER))
         cleanpath(DUMMY_SETTINGS_DIR, self)
 
-        allowed_images = allowed_cloud_images(DUMMY_CONF, DUMMY_USER,
-                                              DUMMY_CLOUD, DUMMY_FLAVOR)
+        allowed_images = allowed_cloud_images(
+            DUMMY_CONF, DUMMY_USER, DUMMY_CLOUD, DUMMY_FLAVOR
+        )
         self.assertTrue(isinstance(allowed_images, list))
         print(allowed_images)
         self.assertTrue(DUMMY_IMAGE in allowed_images)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_cmdapi.py b/tests/test_mig_shared_cmdapi.py
index d5d5b57e0..033bfaeab 100644
--- a/tests/test_mig_shared_cmdapi.py
+++ b/tests/test_mig_shared_cmdapi.py
@@ -29,12 +29,10 @@
 
 import unittest
 
-# Imports required for the unit test wrapping
-
 # Imports of the code under test
 from mig.shared.cmdapi import (
-    get_flag_map,
     get_command_map,
+    get_flag_map,
     get_usage_map,
     legacy_main,
     map_args_to_vars,
@@ -44,6 +42,8 @@
 # Imports required for the unit tests themselves
 from tests.support import MigTestCase, ensure_dirs_exist, testmain
 
+# Imports required for the unit test wrapping
+
 
 class TestMigSharedCmdapi(MigTestCase):
     """Unit tests for cmdapi helpers"""
@@ -60,106 +60,106 @@ def test_get_command_map_basic(self):
         """Test basic command map structure"""
         command_map = get_command_map(self.configuration)
         self.assertIsInstance(command_map, dict)
-        self.assertIn('cp', command_map)
-        self.assertIn('mv', command_map)
-        self.assertIn('rm', command_map)
-        self.assertIn('mkdir', command_map)
-        self.assertIn('du', command_map)
+        self.assertIn("cp", command_map)
+        self.assertIn("mv", command_map)
+        self.assertIn("rm", command_map)
+        self.assertIn("mkdir", command_map)
+        self.assertIn("du", command_map)
 
     def test_get_command_map_with_jobs_enabled(self):
         """Test command map includes job commands when enabled"""
         self.configuration.site_enable_jobs = True
         command_map = get_command_map(self.configuration)
-        self.assertIn('submit', command_map)
-        self.assertIn('canceljob', command_map)
-        self.assertIn('resubmit', command_map)
-        self.assertIn('jobaction', command_map)
-        self.assertIn('liveio', command_map)
+        self.assertIn("submit", command_map)
+        self.assertIn("canceljob", command_map)
+        self.assertIn("resubmit", command_map)
+        self.assertIn("jobaction", command_map)
+        self.assertIn("liveio", command_map)
 
     def test_get_command_map_with_sharelinks_enabled(self):
         """Test command map includes sharelink commands when enabled"""
         self.configuration.site_enable_sharelinks = True
         command_map = get_command_map(self.configuration)
-        self.assertIn('delsharelink', command_map)
-        self.assertIn('importsharelink', command_map)
+        self.assertIn("delsharelink", command_map)
+        self.assertIn("importsharelink", command_map)
 
     def test_get_command_map_with_transfers_enabled(self):
         """Test command map includes transfer commands when enabled"""
         self.configuration.site_enable_transfers = True
         command_map = get_command_map(self.configuration)
-        self.assertIn('datatransfer', command_map)
+        self.assertIn("datatransfer", command_map)
 
     def test_get_command_map_with_freeze_enabled(self):
         """Test command map includes freeze commands when enabled"""
         self.configuration.site_enable_freeze = True
         command_map = get_command_map(self.configuration)
-        self.assertIn('createbackup', command_map)
-        self.assertIn('deletebackup', command_map)
-        self.assertIn('addfreezedata', command_map)
-        self.assertIn('importfreeze', command_map)
+        self.assertIn("createbackup", command_map)
+        self.assertIn("deletebackup", command_map)
+        self.assertIn("addfreezedata", command_map)
+        self.assertIn("importfreeze", command_map)
 
     def test_get_command_map_with_crontab_enabled(self):
         """Test command map includes crontab commands when enabled"""
         self.configuration.site_enable_crontab = True
         command_map = get_command_map(self.configuration)
-        self.assertIn('crontab', command_map)
+        self.assertIn("crontab", command_map)
 
     def test_get_command_map(self):
         """Test that get_command_map returns expected command definitions"""
         cmd_map = get_command_map(self.configuration)
         # Only a subset is relevant for basic tests
         expected_subset = {
-            'pack': ['src', 'dst'],
-            'unpack': ['src', 'dst'],
-            'zip': ['src', 'dst'],
-            'unzip': ['src', 'dst'],
-            'tar': ['src', 'dst'],
-            'untar': ['src', 'dst'],
-            'cp': ['src', 'dst'],
-            'mv': ['src', 'dst'],
-            'rm': ['path'],
-            'du': ['path', 'dst'],
-            'rmdir': ['path'],
-            'truncate': ['path'],
-            'touch': ['path'],
-            'mkdir': ['path'],
-            'chksum': ['hash_algo', 'path', 'dst', 'max_chunks'],
-            'mqueue': ['queue', 'action', 'msg_id', 'msg'],
+            "pack": ["src", "dst"],
+            "unpack": ["src", "dst"],
+            "zip": ["src", "dst"],
+            "unzip": ["src", "dst"],
+            "tar": ["src", "dst"],
+            "untar": ["src", "dst"],
+            "cp": ["src", "dst"],
+            "mv": ["src", "dst"],
+            "rm": ["path"],
+            "du": ["path", "dst"],
+            "rmdir": ["path"],
+            "truncate": ["path"],
+            "touch": ["path"],
+            "mkdir": ["path"],
+            "chksum": ["hash_algo", "path", "dst", "max_chunks"],
+            "mqueue": ["queue", "action", "msg_id", "msg"],
         }
         for cmd, args in expected_subset.items():
             self.assertIn(cmd, cmd_map)
-            self.assertEqual(cmd_map[cmd][:len(args)], args)
+            self.assertEqual(cmd_map[cmd][: len(args)], args)
 
     def test_get_flag_map_structure(self):
         """Test flag map structure"""
         flag_map = get_flag_map(self.configuration)
         self.assertIsInstance(flag_map, dict)
-        self.assertIn('cp', flag_map)
-        self.assertIn('rm', flag_map)
-        self.assertIn('du', flag_map)
-        self.assertIn('mkdir', flag_map)
-        self.assertIn('rmdir', flag_map)
+        self.assertIn("cp", flag_map)
+        self.assertIn("rm", flag_map)
+        self.assertIn("du", flag_map)
+        self.assertIn("mkdir", flag_map)
+        self.assertIn("rmdir", flag_map)
 
     def test_get_flag_map_values(self):
         """Test flag map values"""
         flag_map = get_flag_map(self.configuration)
-        self.assertEqual(flag_map['cp'], ['r', 'f'])
-        self.assertEqual(flag_map['rm'], ['r', 'f'])
-        self.assertEqual(flag_map['du'], ['s'])
-        self.assertEqual(flag_map['mkdir'], ['p'])
-        self.assertEqual(flag_map['rmdir'], ['p'])
+        self.assertEqual(flag_map["cp"], ["r", "f"])
+        self.assertEqual(flag_map["rm"], ["r", "f"])
+        self.assertEqual(flag_map["du"], ["s"])
+        self.assertEqual(flag_map["mkdir"], ["p"])
+        self.assertEqual(flag_map["rmdir"], ["p"])
 
     def test_get_flag_map(self):
         """Test that get_flag_map returns expected flag definitions"""
         flags = get_flag_map(self.configuration)
         expected = {
-            'cp': ['r', 'f'],
-            'rm': ['r', 'f'],
-            'du': ['s'],
-            'mkdir': ['p'],
-            'rmdir': ['p'],
-            'importsharelink': ['r', 'f'],
-            'importfreeze': ['r', 'f'],
+            "cp": ["r", "f"],
+            "rm": ["r", "f"],
+            "du": ["s"],
+            "mkdir": ["p"],
+            "rmdir": ["p"],
+            "importsharelink": ["r", "f"],
+            "importfreeze": ["r", "f"],
         }
         self.assertEqual(flags, expected)
 
@@ -167,130 +167,128 @@ def test_get_usage_map_structure(self):
         """Test usage map structure"""
         usage_map = get_usage_map(self.configuration)
         self.assertIsInstance(usage_map, dict)
-        self.assertIn('cp', usage_map)
-        self.assertIn('mv', usage_map)
-        self.assertIn('rm', usage_map)
-        self.assertIn('mkdir', usage_map)
-        self.assertIn('du', usage_map)
+        self.assertIn("cp", usage_map)
+        self.assertIn("mv", usage_map)
+        self.assertIn("rm", usage_map)
+        self.assertIn("mkdir", usage_map)
+        self.assertIn("du", usage_map)
 
     def test_get_usage_map_values(self):
         """Test usage map values"""
         usage_map = get_usage_map(self.configuration)
-        self.assertEqual(usage_map['cp'], 'cp [-r] [-f] SRC [SRC ..] DST')
-        self.assertEqual(usage_map['mv'], 'mv SRC [SRC ..] DST')
-        self.assertEqual(usage_map['rm'], 'rm [-r] [-f] PATH [PATH ..]')
-        self.assertEqual(usage_map['mkdir'], 'mkdir [-p] PATH [PATH ..]')
-        self.assertEqual(usage_map['du'], 'du [-s] PATH [PATH ..] DST')
+        self.assertEqual(usage_map["cp"], "cp [-r] [-f] SRC [SRC ..] DST")
+        self.assertEqual(usage_map["mv"], "mv SRC [SRC ..] DST")
+        self.assertEqual(usage_map["rm"], "rm [-r] [-f] PATH [PATH ..]")
+        self.assertEqual(usage_map["mkdir"], "mkdir [-p] PATH [PATH ..]")
+        self.assertEqual(usage_map["du"], "du [-s] PATH [PATH ..] DST")
 
     def test_get_usage_map(self):
         """Test that get_usage_map builds usage strings correctly"""
         usage = get_usage_map(self.configuration)
         # Check a known command
-        self.assertIn('cp', usage)
-        self.assertIn('[-r]', usage['cp'])
-        self.assertIn('SRC [SRC ..] DST', usage['cp'])
+        self.assertIn("cp", usage)
+        self.assertIn("[-r]", usage["cp"])
+        self.assertIn("SRC [SRC ..] DST", usage["cp"])
 
     def test_map_args_to_vars_variable_length(self):
         """Test that map_args_to_vars expands variable length arguments"""
-        var_list = ['src', 'dst']
-        arg_list = ['a.txt', 'b.txt', 'c.txt']
+        var_list = ["src", "dst"]
+        arg_list = ["a.txt", "b.txt", "c.txt"]
         result = map_args_to_vars(var_list, arg_list)
-        self.assertEqual(
-            result, {'src': ['a.txt', 'b.txt'], 'dst': ['c.txt']}
-        )
+        self.assertEqual(result, {"src": ["a.txt", "b.txt"], "dst": ["c.txt"]})
 
     def test_map_args_to_vars_exact_match(self):
         """Test map_args_to_vars with exact number of arguments"""
-        var_list = ['src', 'dst']
-        arg_list = ['a.txt', 'b.txt']
+        var_list = ["src", "dst"]
+        arg_list = ["a.txt", "b.txt"]
         result = map_args_to_vars(var_list, arg_list)
-        self.assertEqual(result, {'src': ['a.txt'], 'dst': ['b.txt']})
+        self.assertEqual(result, {"src": ["a.txt"], "dst": ["b.txt"]})
 
     def test_parse_command_args_basic(self):
         """Test that parse_command_args parses a simple command correctly"""
-        cmd_list = ['cp', 'srcfile', 'dstfile']
+        cmd_list = ["cp", "srcfile", "dstfile"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'cp')
-        self.assertEqual(args_dict.get('src'), ['srcfile'])
-        self.assertEqual(args_dict.get('dst'), ['dstfile'])
+        self.assertEqual(backend, "cp")
+        self.assertEqual(args_dict.get("src"), ["srcfile"])
+        self.assertEqual(args_dict.get("dst"), ["dstfile"])
 
     def test_parse_command_args_with_flags(self):
         """Test that parse_command_args handles flags correctly"""
-        cmd_list = ['cp', '-r', 'srcdir', 'dstdir']
+        cmd_list = ["cp", "-r", "srcdir", "dstdir"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'cp')
-        self.assertIn('flags', args_dict)
-        self.assertEqual(args_dict['flags'], ['r'])
+        self.assertEqual(backend, "cp")
+        self.assertIn("flags", args_dict)
+        self.assertEqual(args_dict["flags"], ["r"])
 
     def test_parse_command_args_with_multiple_flags(self):
         """Test that parse_command_args handles multiple combined flags"""
-        cmd_list = ['cp', '-rf', 'srcdir', 'dstdir']
+        cmd_list = ["cp", "-rf", "srcdir", "dstdir"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'cp')
-        self.assertIn('flags', args_dict)
-        self.assertEqual(args_dict['flags'], ['rf'])
+        self.assertEqual(backend, "cp")
+        self.assertIn("flags", args_dict)
+        self.assertEqual(args_dict["flags"], ["rf"])
 
     def test_parse_command_args_unsupported(self):
         """Test that parse_command_args raises on unsupported command"""
-        cmd_list = ['unknown_cmd', 'arg1']
+        cmd_list = ["unknown_cmd", "arg1"]
         with self.assertRaises(ValueError) as cm:
             parse_command_args(self.configuration, cmd_list)
-        self.assertIn('unsupported command', str(cm.exception))
+        self.assertIn("unsupported command", str(cm.exception))
 
     def test_parse_command_args_delsharelink_no_flags_entry(self):
         """Regression: ensure commands without flags don't produce flags key"""
         self.configuration.site_enable_sharelinks = True
-        cmd_list = ['delsharelink', 'share123']
+        cmd_list = ["delsharelink", "share123"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'delsharelink')
-        self.assertNotIn('flags', args_dict)
-        self.assertEqual(args_dict.get('share_id'), ['share123'])
+        self.assertEqual(backend, "delsharelink")
+        self.assertNotIn("flags", args_dict)
+        self.assertEqual(args_dict.get("share_id"), ["share123"])
 
     def test_parse_command_args_canceljob_no_flags_entry(self):
         """Regression: ensure commands without flags don't produce flags key"""
         self.configuration.site_enable_jobs = True
-        cmd_list = ['canceljob', 'job123']
+        cmd_list = ["canceljob", "job123"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'canceljob')
-        self.assertNotIn('flags', args_dict)
-        self.assertEqual(args_dict.get('job_id'), ['job123'])
+        self.assertEqual(backend, "canceljob")
+        self.assertNotIn("flags", args_dict)
+        self.assertEqual(args_dict.get("job_id"), ["job123"])
 
     def test_parse_command_args_datatransfer_no_flags_entry(self):
         """Regression: ensure commands without flags don't produce flags key"""
         self.configuration.site_enable_transfers = True
-        cmd_list = ['datatransfer', 'transfer123']
+        cmd_list = ["datatransfer", "transfer123"]
        backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'datatransfer')
-        self.assertNotIn('flags', args_dict)
-        self.assertEqual(args_dict.get('transfer_id'), ['transfer123'])
+        self.assertEqual(backend, "datatransfer")
+        self.assertNotIn("flags", args_dict)
+        self.assertEqual(args_dict.get("transfer_id"), ["transfer123"])
 
     def test_parse_command_args_deletebackup_no_flags_entry(self):
         """Regression: ensure commands without flags don't produce flags key"""
         self.configuration.site_enable_freeze = True
-        cmd_list = ['deletebackup', 'backup123']
+        cmd_list = ["deletebackup", "backup123"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'deletebackup')
-        self.assertNotIn('flags', args_dict)
-        self.assertEqual(args_dict.get('freeze_id'), ['backup123'])
+        self.assertEqual(backend, "deletebackup")
+        self.assertNotIn("flags", args_dict)
+        self.assertEqual(args_dict.get("freeze_id"), ["backup123"])
 
     def test_parse_command_args_crontab_no_flags_entry(self):
         """Regression: ensure commands without flags don't produce flags key"""
         self.configuration.site_enable_crontab = True
-        cmd_list = ['crontab', 'transfer123', 'reschedule']
+        cmd_list = ["crontab", "transfer123", "reschedule"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'crontab')
-        self.assertNotIn('flags', args_dict)
-        self.assertEqual(args_dict.get('action'), ['reschedule'])
+        self.assertEqual(backend, "crontab")
+        self.assertNotIn("flags", args_dict)
+        self.assertEqual(args_dict.get("action"), ["reschedule"])
 
     def test_parse_command_args_mqueue_no_flags_entry(self):
         """Regression: ensure commands without flags don't produce flags key"""
         self.configuration.site_enable_jobs = True
-        cmd_list = ['mqueue', 'testqueue', 'msgaction', 'msgid', 'test msg']
+        cmd_list = ["mqueue", "testqueue", "msgaction", "msgid", "test msg"]
         backend, args_dict = parse_command_args(self.configuration, cmd_list)
-        self.assertEqual(backend, 'mqueue')
-        self.assertNotIn('flags', args_dict)
-        self.assertEqual(args_dict.get('queue'), ['testqueue'])
-        self.assertEqual(args_dict.get('msg'), ['test msg'])
+        self.assertEqual(backend, "mqueue")
+        self.assertNotIn("flags", args_dict)
+        self.assertEqual(args_dict.get("queue"), ["testqueue"])
+        self.assertEqual(args_dict.get("msg"), ["test msg"])
 
 
 class TestMigSharedCmdapi__legacy_main(MigTestCase):
@@ -318,5 +316,5 @@ def record_last_print(value):
         legacy_main(_exit=raise_on_error_exit, _print=record_last_print)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_compat.py b/tests/test_mig_shared_compat.py
index 05f2de2bc..faad707f7 100644
--- a/tests/test_mig_shared_compat.py
+++ b/tests/test_mig_shared_compat.py
@@ -31,13 +31,13 @@
 import os
 import sys
 
+from mig.shared.compat import PY2, ensure_native_string
 from tests.support import MigTestCase, testmain
 
-from mig.shared.compat import PY2, ensure_native_string
+DUMMY_BYTECHARS = b"DEADBEEF"
+DUMMY_BYTESRAW = binascii.unhexlify("DEADBEEF")  # 4 bytes
+DUMMY_UNICODE = "UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®"
 
-DUMMY_BYTECHARS = b'DEADBEEF'
-DUMMY_BYTESRAW = binascii.unhexlify('DEADBEEF')  # 4 bytes
-DUMMY_UNICODE = u'UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®'
 
 
 class MigSharedCompat__ensure_native_string(MigTestCase):
@@ -45,7 +45,7 @@ class MigSharedCompat__ensure_native_string(MigTestCase):
     def test_char_bytes_conversion(self):
         actual = ensure_native_string(DUMMY_BYTECHARS)
         self.assertIs(type(actual), str)
-        self.assertEqual(actual, 'DEADBEEF')
+        self.assertEqual(actual, "DEADBEEF")
 
     def test_raw_bytes_conversion(self):
         with self.assertRaises(UnicodeDecodeError):
@@ -60,5 +60,5 @@ def test_unicode_conversion(self):
         self.assertEqual(actual, DUMMY_UNICODE)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_configuration.py b/tests/test_mig_shared_configuration.py
index 1108ea3e0..c02b33559 100644
--- a/tests/test_mig_shared_configuration.py
+++ b/tests/test_mig_shared_configuration.py
@@ -31,16 +31,23 @@
 import os
 import unittest
 
-from tests.support import MigTestCase, TEST_DATA_DIR, PY2, testmain
+from mig.shared.configuration import (
+    _CONFIGURATION_ARGUMENTS,
+    _CONFIGURATION_PROPERTIES,
+    Configuration,
+)
+from tests.support import PY2, TEST_DATA_DIR, MigTestCase, testmain
 from tests.support.fixturesupp import FixtureAssertMixin
 
-from mig.shared.configuration import Configuration, \
-    _CONFIGURATION_ARGUMENTS, _CONFIGURATION_PROPERTIES
-
 
 def _to_dict(obj):
-    return {k: v for k, v in inspect.getmembers(obj)
-            if not (k.startswith('__') or inspect.ismethod(v) or inspect.isfunction(v))}
+    return {
+        k: v
+        for k, v in inspect.getmembers(obj)
+        if not (
+            k.startswith("__") or inspect.ismethod(v) or inspect.isfunction(v)
+        )
+    }
 
 
 class MigSharedConfiguration__static_definitions(MigTestCase):
@@ -50,8 +57,9 @@ def test_consistent_parameters(self):
         configuration_defaults_keys = set(_CONFIGURATION_PROPERTIES.keys())
         mismatched = _CONFIGURATION_ARGUMENTS - configuration_defaults_keys
 
-        self.assertEqual(len(mismatched), 0,
-                         "configuration defaults do not match arguments")
+        self.assertEqual(
+            len(mismatched), 0, "configuration defaults do not match arguments"
+        )
 
 
 class MigSharedConfiguration__loaded_configurations(MigTestCase):
@@ -59,19 +67,23 @@ class MigSharedConfiguration__loaded_configurations(MigTestCase):
 
     def test_argument_new_user_default_ui_is_replaced(self):
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised.conf')
+            TEST_DATA_DIR, "MiGserver--customised.conf"
+        )
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
-        self.assertEqual(configuration.new_user_default_ui, 'V3')
+        self.assertEqual(configuration.new_user_default_ui, "V3")
 
     def test_argument_storage_protocols(self):
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised.conf')
+            TEST_DATA_DIR, "MiGserver--customised.conf"
+        )
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         # TODO: add a test to cover filtering of a mix of valid+invalid protos
         # self.assertEqual(configuration.storage_protocols, ['xxx', 'yyy', 'zzz'])
@@ -81,90 +93,110 @@ def test_argument_storage_protocols(self):
 
     def test_argument_wwwserve_max_bytes(self):
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised.conf')
+            TEST_DATA_DIR, "MiGserver--customised.conf"
+        )
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         self.assertEqual(configuration.wwwserve_max_bytes, 43211234)
 
     def test_argument_include_sections(self):
         """Test that include_sections path default is set"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised.conf')
+            TEST_DATA_DIR, "MiGserver--customised.conf"
+        )
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
-        self.assertEqual(configuration.include_sections,
-                         '/home/mig/mig/server/MiGserver.d')
+        self.assertEqual(
+            configuration.include_sections, "/home/mig/mig/server/MiGserver.d"
+        )
 
     def test_argument_custom_include_sections(self):
         """Test that include_sections path override is correctly applied"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
         self.assertTrue(os.path.isdir(test_conf_section_dir))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
-        self.assertEqual(configuration.include_sections,
-                         test_conf_section_dir)
+        self.assertEqual(configuration.include_sections, test_conf_section_dir)
 
     def test_argument_include_sections_quota(self):
         """Test that QUOTA conf section overrides are correctly applied"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'quota.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "quota.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
-        self.assertEqual(configuration.quota_backend, 'dummy')
+        self.assertEqual(configuration.quota_backend, "dummy")
         self.assertEqual(configuration.quota_user_limit, 4242)
         self.assertEqual(configuration.quota_vgrid_limit, 4242424242)
 
     def test_argument_include_sections_cloud_misty(self):
         """Test that CLOUD_MISTY conf section is correctly applied"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'cloud_misty.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "cloud_misty.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
         self.assertIsInstance(configuration.cloud_services, list)
         self.assertTrue(configuration.cloud_services)
         self.assertIsInstance(configuration.cloud_services[0], dict)
-        self.assertTrue(configuration.cloud_services[0].get('service_name',
-                                                            False))
-        self.assertEqual(configuration.cloud_services[0]['service_name'],
-                         'MISTY')
-        self.assertEqual(configuration.cloud_services[0]['service_desc'],
-                         'MISTY service')
-        self.assertEqual(configuration.cloud_services[0]['service_provider_flavor'],
-                         'nostack')
+        self.assertTrue(
+            configuration.cloud_services[0].get("service_name", False)
+        )
+        self.assertEqual(
+            configuration.cloud_services[0]["service_name"], "MISTY"
+        )
+        self.assertEqual(
+            configuration.cloud_services[0]["service_desc"], "MISTY service"
+        )
+        self.assertEqual(
+            configuration.cloud_services[0]["service_provider_flavor"],
+            "nostack",
+        )
 
     def test_argument_include_sections_global_accepted(self):
         """Test that peripheral GLOBAL conf overrides are accepted (policy)"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'global.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "global.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
         self.assertEqual(configuration.admin_email, "admin@somewhere.org")
@@ -176,93 +208,105 @@ def test_argument_include_sections_global_accepted(self):
 
     def test_argument_include_sections_global_rejected(self):
         """Test that core GLOBAL conf overrides are rejected (policy)"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'global.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "global.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         # Run through the snippet values and check that override didn't succeed
         # and then that default is left set. The former _could_ be left out but
         # is kept explicit for clarity in case something breaks by changes.
-        self.assertNotEqual(configuration.include_sections, '/tmp/MiGserver.d')
+        self.assertNotEqual(configuration.include_sections, "/tmp/MiGserver.d")
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
-        self.assertNotEqual(configuration.mig_path, '/tmp/mig/mig')
-        self.assertEqual(configuration.mig_path, '/home/mig/mig')
-        self.assertNotEqual(configuration.logfile, '/tmp/mig.log')
-        self.assertEqual(configuration.logfile, 'mig.log')
-        self.assertNotEqual(configuration.loglevel, 'warning')
-        self.assertEqual(configuration.loglevel, 'info')
-        self.assertNotEqual(configuration.server_fqdn, 'somewhere.org')
-        self.assertEqual(configuration.server_fqdn, '')
-        self.assertNotEqual(configuration.migserver_public_url,
-                            'https://somewhere.org')
-        self.assertEqual(configuration.migserver_public_url, '')
-        self.assertNotEqual(configuration.migserver_https_sid_url,
-                            'https://somewhere.org')
-        self.assertEqual(configuration.migserver_https_sid_url, '')
-        self.assertNotEqual(configuration.user_openid_address, 'somewhere.org')
-        self.assertNotEqual(configuration.user_openid_address, 'somewhere.org')
-        self.assertEqual(configuration.user_openid_address, '')
+        self.assertNotEqual(configuration.mig_path, "/tmp/mig/mig")
+        self.assertEqual(configuration.mig_path, "/home/mig/mig")
+        self.assertNotEqual(configuration.logfile, "/tmp/mig.log")
+        self.assertEqual(configuration.logfile, "mig.log")
+        self.assertNotEqual(configuration.loglevel, "warning")
+        self.assertEqual(configuration.loglevel, "info")
+        self.assertNotEqual(configuration.server_fqdn, "somewhere.org")
+        self.assertEqual(configuration.server_fqdn, "")
+        self.assertNotEqual(
+            configuration.migserver_public_url, "https://somewhere.org"
+        )
+        self.assertEqual(configuration.migserver_public_url, "")
+        self.assertNotEqual(
+            configuration.migserver_https_sid_url, "https://somewhere.org"
+        )
+        self.assertEqual(configuration.migserver_https_sid_url, "")
+        self.assertNotEqual(configuration.user_openid_address, "somewhere.org")
+        self.assertNotEqual(configuration.user_openid_address, "somewhere.org")
+        self.assertEqual(configuration.user_openid_address, "")
         self.assertNotEqual(configuration.user_openid_port, 4242)
         self.assertEqual(configuration.user_openid_port, 8443)
-        self.assertNotEqual(configuration.user_openid_key, '/tmp/openid.key')
-        self.assertEqual(configuration.user_openid_key, '')
-        self.assertNotEqual(configuration.user_openid_log, '/tmp/openid.log')
-        self.assertEqual(configuration.user_openid_log,
-                         '/home/mig/state/log/openid.log')
+        self.assertNotEqual(configuration.user_openid_key, "/tmp/openid.key")
+        self.assertEqual(configuration.user_openid_key, "")
+        self.assertNotEqual(configuration.user_openid_log, "/tmp/openid.log")
+        self.assertEqual(
+            configuration.user_openid_log, "/home/mig/state/log/openid.log"
+        )
 
     def test_argument_include_sections_site_accepted(self):
         """Test that peripheral SITE conf overrides are accepted (policy)"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'site.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "site.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
-        self.assertEqual(configuration.short_title, 'ACME Site')
-        self.assertEqual(configuration.new_user_default_ui, 'V3')
-        self.assertEqual(configuration.site_password_legacy_policy, 'MEDIUM')
-        self.assertEqual(configuration.site_support_text,
-                         'Custom support text')
-        self.assertEqual(configuration.site_privacy_text,
-                         'Custom privacy text')
-        self.assertEqual(configuration.site_peers_notice,
-                         'Custom peers notice')
-        self.assertEqual(configuration.site_peers_contact_hint,
-                         'Custom peers contact hint')
+        self.assertEqual(configuration.short_title, "ACME Site")
+        self.assertEqual(configuration.new_user_default_ui, "V3")
+        self.assertEqual(configuration.site_password_legacy_policy, "MEDIUM")
+        self.assertEqual(configuration.site_support_text, "Custom support text")
+        self.assertEqual(configuration.site_privacy_text, "Custom privacy text")
+        self.assertEqual(configuration.site_peers_notice, "Custom peers notice")
+        self.assertEqual(
+            configuration.site_peers_contact_hint, "Custom peers contact hint"
+        )
         self.assertIsInstance(configuration.site_freeze_admins, list)
         self.assertTrue(len(configuration.site_freeze_admins) == 1)
-        self.assertTrue('BOFH' in configuration.site_freeze_admins)
-        self.assertEqual(configuration.site_freeze_to_tape,
-                         'Custom freeze to tape')
-        self.assertEqual(configuration.site_freeze_doi_text,
-                         'Custom freeze doi text')
-        self.assertEqual(configuration.site_freeze_doi_url,
-                         'https://somewhere.org/mint-doi')
-        self.assertEqual(configuration.site_freeze_doi_url_field,
-                         'archiveurl')
+        self.assertTrue("BOFH" in configuration.site_freeze_admins)
+        self.assertEqual(
+            configuration.site_freeze_to_tape, "Custom freeze to tape"
+        )
+        self.assertEqual(
+            configuration.site_freeze_doi_text, "Custom freeze doi text"
+        )
+        self.assertEqual(
+            configuration.site_freeze_doi_url, "https://somewhere.org/mint-doi"
+        )
+        self.assertEqual(configuration.site_freeze_doi_url_field, "archiveurl")
 
     def test_argument_include_sections_site_rejected(self):
         """Test that core SITE conf overrides are rejected (policy)"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'site.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "site.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
         self.assertEqual(configuration.site_enable_openid, False)
@@ -279,56 +323,63 @@ def test_argument_include_sections_site_rejected(self):
 
     def test_argument_include_sections_with_invalid_conf_filename(self):
         """Test that conf snippet with missing .conf extension gets ignored"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'dummy')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(test_conf_section_dir, "dummy")
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         # Conf only contains SETTINGS section which is ignored due to mismatch
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
         self.assertIsInstance(configuration.language, list)
-        self.assertFalse('Pig Latin' in configuration.language)
-        self.assertEqual(configuration.language, ['English'])
+        self.assertFalse("Pig Latin" in configuration.language)
+        self.assertEqual(configuration.language, ["English"])
 
     def test_argument_include_sections_with_section_name_mismatch(self):
         """Test that conf section must match filename"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'section-mismatch.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "section-mismatch.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         # Conf only contains SETTINGS section which is ignored due to mismatch
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
         self.assertIsInstance(configuration.language, list)
-        self.assertFalse('Pig Latin' in configuration.language)
-        self.assertEqual(configuration.language, ['English'])
+        self.assertFalse("Pig Latin" in configuration.language)
+        self.assertEqual(configuration.language, ["English"])
 
     def test_argument_include_sections_multi_ignores_other_sections(self):
         """Test that conf section must match filename and others are ignored"""
         test_conf_file = os.path.join(
-            TEST_DATA_DIR, 'MiGserver--customised-include_sections.conf')
-        test_conf_section_dir = os.path.join('tests', 'data', 'MiGserver.d')
-        test_conf_section_file = os.path.join(test_conf_section_dir,
-                                              'multi.conf')
+            TEST_DATA_DIR, "MiGserver--customised-include_sections.conf"
+        )
+        test_conf_section_dir = os.path.join("tests", "data", "MiGserver.d")
+        test_conf_section_file = os.path.join(
+            test_conf_section_dir, "multi.conf"
+        )
         self.assertTrue(os.path.isfile(test_conf_section_file))
         configuration = Configuration(
-            test_conf_file, skip_log=True, disable_auth_log=True)
+            test_conf_file, skip_log=True, disable_auth_log=True
+        )
 
         # Conf contains MULTI and SETTINGS sections and latter must be ignored
         self.assertEqual(configuration.include_sections, test_conf_section_dir)
         self.assertIsInstance(configuration.language, list)
-        self.assertFalse('Spanglish' in configuration.language)
-        self.assertEqual(configuration.language, ['English'])
+        self.assertFalse("Spanglish" in configuration.language)
+        self.assertEqual(configuration.language, ["English"])
 
         # TODO: rename file to valid section name we can check and enable next?
         # self.assertEqual(configuration.multi, 'blabla')
@@ -339,15 +390,16 @@ class MigSharedConfiguration__new_instance(MigTestCase, FixtureAssertMixin):
 
     @unittest.skipIf(PY2, "Python 3 only")
     def test_default_object(self):
         prepared_fixture = self.prepareFixtureAssert(
-            'mig_shared_configuration--new', fixture_format='json')
+            "mig_shared_configuration--new", fixture_format="json"
+        )
 
         configuration = Configuration(None)
 
         # TODO: the following work-around default values set for these on the
         # instance that no longer make total sense but fiddling with them
         # is better as a follow-up.
-        configuration.certs_path = '/some/place/certs'
-        configuration.state_path = '/some/place/state'
-        configuration.mig_path = '/some/place/mig'
+        configuration.certs_path = "/some/place/certs"
+        configuration.state_path = "/some/place/state"
+        configuration.mig_path = "/some/place/mig"
 
         actual_values = _to_dict(configuration)
 
@@ -358,11 +410,11 @@ def test_object_isolation(self):
         configuration_2 = Configuration(None)
 
         # change one of the configuration objects
-        configuration_1.default_page.append('foobar')
+        configuration_1.default_page.append("foobar")
 
         # check the other was not affected
-        self.assertEqual(configuration_2.default_page, [''])
+        self.assertEqual(configuration_2.default_page, [""])
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_fileio.py b/tests/test_mig_shared_fileio.py
index f43f013e8..277da3a7c 100644
--- a/tests/test_mig_shared_fileio.py
+++ b/tests/test_mig_shared_fileio.py
@@ -35,52 +35,53 @@
 
 # Imports of the code under test
 import mig.shared.fileio as fileio
+
 # Imports required for the unit tests themselves
 from tests.support import MigTestCase, ensure_dirs_exist, testmain
 
-DUMMY_BYTES = binascii.unhexlify('DEADBEEF')  # 4 bytes
+DUMMY_BYTES = binascii.unhexlify("DEADBEEF")  # 4 bytes
 DUMMY_BYTES_LENGTH = 4
-DUMMY_UNICODE = u'UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®'
+DUMMY_UNICODE = "UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®"
 DUMMY_UNICODE_LENGTH = len(DUMMY_UNICODE)
-DUMMY_TEXT = 'dummy'
-DUMMY_TWICE = 'dummy - dummy'
-DUMMY_TESTDIR = 'fileio'
-DUMMY_SUBDIR = 'subdir'
-DUMMY_FILE_ONE = 'file1.txt'
-DUMMY_FILE_TWO = 'file2.txt'
-DUMMY_FILE_MISSING = 'missing.txt'
-DUMMY_FILE_RO = 'readonly.txt'
-DUMMY_FILE_WO = 'writeonly.txt'
-DUMMY_FILE_RW = 'readwrite.txt'
-DUMMY_DIRECTORY_NESTED = 'nested/dir/structure'
-DUMMY_DIRECTORY_EMPTY = 'empty_dir'
-DUMMY_DIRECTORY_MOVE_SRC = 'move_dir_src'
-DUMMY_DIRECTORY_MOVE_DST = 'move_dir_dst'
-DUMMY_DIRECTORY_REMOVE = 'remove_dir'
-DUMMY_DIRECTORY_CHECKACCESS = 'check_access'
-DUMMY_DIRECTORY_MAKEDIRSREC = 'makedirs_rec'
-DUMMY_DIRECTORY_COPYRECSRC = 'copy_dir_src'
-DUMMY_DIRECTORY_COPYRECDST = 'copy_dir_dst'
-DUMMY_DIRECTORY_REMOVEREC = 'remove_rec' +DUMMY_TEXT = "dummy" +DUMMY_TWICE = "dummy - dummy" +DUMMY_TESTDIR = "fileio" +DUMMY_SUBDIR = "subdir" +DUMMY_FILE_ONE = "file1.txt" +DUMMY_FILE_TWO = "file2.txt" +DUMMY_FILE_MISSING = "missing.txt" +DUMMY_FILE_RO = "readonly.txt" +DUMMY_FILE_WO = "writeonly.txt" +DUMMY_FILE_RW = "readwrite.txt" +DUMMY_DIRECTORY_NESTED = "nested/dir/structure" +DUMMY_DIRECTORY_EMPTY = "empty_dir" +DUMMY_DIRECTORY_MOVE_SRC = "move_dir_src" +DUMMY_DIRECTORY_MOVE_DST = "move_dir_dst" +DUMMY_DIRECTORY_REMOVE = "remove_dir" +DUMMY_DIRECTORY_CHECKACCESS = "check_access" +DUMMY_DIRECTORY_MAKEDIRSREC = "makedirs_rec" +DUMMY_DIRECTORY_COPYRECSRC = "copy_dir_src" +DUMMY_DIRECTORY_COPYRECDST = "copy_dir_dst" +DUMMY_DIRECTORY_REMOVEREC = "remove_rec" # File/dir paths for move/copy operations -DUMMY_FILE_MOVE_SRC = 'move_src' -DUMMY_FILE_MOVE_DST = 'move_dst' -DUMMY_FILE_COPY_SRC = 'copy_src' -DUMMY_FILE_COPY_DST = 'copy_dst' -DUMMY_FILE_WRITECHUNK = 'write_chunk' -DUMMY_FILE_WRITEFILE = 'write_file' -DUMMY_FILE_WRITEFILELINES = 'write_file_lines' -DUMMY_FILE_READFILE = 'read_file' -DUMMY_FILE_READFILELINES = 'read_file_lines' -DUMMY_FILE_READHEADLINES = 'read_head_lines' -DUMMY_FILE_READTAILLINES = 'read_tail_lines' -DUMMY_FILE_DELETEFILE = 'delete_file' -DUMMY_FILE_GETFILESIZE = 'get_file_size' -DUMMY_FILE_MAKESYMLINKSRC = 'link_src' -DUMMY_FILE_MAKESYMLINKDST = 'link_target' -DUMMY_FILE_DELETESYMLINKSRC = 'link_src' -DUMMY_FILE_DELETESYMLINKDST = 'link_target' -DUMMY_FILE_TOUCH = 'touch_file' +DUMMY_FILE_MOVE_SRC = "move_src" +DUMMY_FILE_MOVE_DST = "move_dst" +DUMMY_FILE_COPY_SRC = "copy_src" +DUMMY_FILE_COPY_DST = "copy_dst" +DUMMY_FILE_WRITECHUNK = "write_chunk" +DUMMY_FILE_WRITEFILE = "write_file" +DUMMY_FILE_WRITEFILELINES = "write_file_lines" +DUMMY_FILE_READFILE = "read_file" +DUMMY_FILE_READFILELINES = "read_file_lines" +DUMMY_FILE_READHEADLINES = "read_head_lines" +DUMMY_FILE_READTAILLINES = "read_tail_lines" +DUMMY_FILE_DELETEFILE = "delete_file" +DUMMY_FILE_GETFILESIZE = "get_file_size" +DUMMY_FILE_MAKESYMLINKSRC = "link_src" +DUMMY_FILE_MAKESYMLINKDST = "link_target" +DUMMY_FILE_DELETESYMLINKSRC = "link_src" +DUMMY_FILE_DELETESYMLINKDST = "link_target" +DUMMY_FILE_TOUCH = "touch_file" assert isinstance(DUMMY_BYTES, bytes) @@ -90,12 +91,13 @@ class MigSharedFileio__temporary_umask(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_file_one = os.path.join(self.tmp_base, DUMMY_FILE_ONE) try: @@ -121,63 +123,63 @@ def before_each(self): def test_creates_new_file_with_temporary_umask_777(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o777): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o000) def test_creates_new_file_with_temporary_umask_277(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o277): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) 
self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o400) def test_creates_new_file_with_temporary_umask_227(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o227): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o440) def test_creates_new_file_with_temporary_umask_077(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o077): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o600) def test_creates_new_file_with_temporary_umask_027(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o027): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o640) def test_creates_new_file_with_temporary_umask_007(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o007): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o660) def test_creates_new_file_with_temporary_umask_022(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o022): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o644) def test_creates_new_file_with_temporary_umask_002(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o002): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o664) def test_creates_new_file_with_temporary_umask_000(self): """Test create file with permissions restricted by given temp umask""" with fileio.temporary_umask(0o000): - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() self.assertTrue(os.path.isfile(self.tmp_file_one)) self.assertEqual(os.stat(self.tmp_file_one).st_mode & 0o777, 0o666) @@ -257,8 +259,9 @@ def test_restores_original_umask_after_exit(self): current_umask = os.umask(original_umask) # Retrieve and reset # Cleanup: Restore environment os.umask(current_umask) - self.assertEqual(current_umask, 0o022, - "Failed to restore original umask") + self.assertEqual( + current_umask, 0o022, "Failed to restore original umask" + ) finally: os.umask(original_umask) # Ensure cleanup @@ -267,20 +270,20 @@ def test_nested_temporary_umask(self): original_umask = os.umask(0o022) try: with fileio.temporary_umask(0o027): # Outer context - open(self.tmp_file_one, 'w').close() + open(self.tmp_file_one, "w").close() mode1 = os.stat(self.tmp_file_one).st_mode & 0o777 self.assertEqual(mode1, 0o640) # 666 & ~027 = 640 with fileio.temporary_umask(0o077): # Inner context - open(self.tmp_file_two, 'w').close() + open(self.tmp_file_two, "w").close() mode2 = os.stat(self.tmp_file_two).st_mode & 0o777 self.assertEqual(mode2, 0o600) # 666 & 
~077
             # Back to outer context umask
             os.remove(self.tmp_file_one)
-            open(self.tmp_file_one, 'w').close()
+            open(self.tmp_file_one, "w").close()
             mode1_after = os.stat(self.tmp_file_one).st_mode & 0o777
             self.assertEqual(mode1_after, 0o640)  # 666 & ~027
         # Back to original umask
-        open(self.tmp_file_one, 'w').close()
+        open(self.tmp_file_one, "w").close()
         mode_original = os.stat(self.tmp_file_one).st_mode & 0o777
         self.assertEqual(mode_original, 0o640)  # unchanged: open("w") keeps the mode of an existing file
     finally:
@@ -309,7 +312,7 @@ def test_restores_umask_after_exception(self):

     def test_umask_does_not_affect_existing_files(self):
         """Test temporary_umask doesn't modify existing file permissions"""
-        open(self.tmp_file_one, 'w').close()
+        open(self.tmp_file_one, "w").close()
         os.chmod(self.tmp_file_one, 0o644)  # Explicit permissions
         with fileio.temporary_umask(0o077):  # Shouldn't affect existing file
             # Change permissions inside context
@@ -324,12 +327,13 @@ class MigSharedFileio__write_chunk(MigTestCase):

     def _provide_configuration(self):
         """Set up isolated test configuration and logger for the tests"""
-        return 'testconfig'
+        return "testconfig"

     def before_each(self):
         """Setup test environment before each test method"""
-        self.tmp_base = os.path.join(self.configuration.mig_system_run,
-                                     DUMMY_TESTDIR)
+        self.tmp_base = os.path.join(
+            self.configuration.mig_system_run, DUMMY_TESTDIR
+        )
         ensure_dirs_exist(self.tmp_base)
         self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_WRITECHUNK)

@@ -338,16 +342,18 @@ def test_return_false_on_invalid_data(self):
         self.logger.forgive_errors()

         # NOTE: we make sure to disable any forced stringification here
-        did_succeed = fileio.write_chunk(self.tmp_path, 1234, 0, self.logger,
-                                         force_string=False)
+        did_succeed = fileio.write_chunk(
+            self.tmp_path, 1234, 0, self.logger, force_string=False
+        )
         self.assertFalse(did_succeed)

     def test_return_false_on_invalid_offset(self):
         """Test write_chunk returns False with negative offset value"""
         self.logger.forgive_errors()

-        did_succeed = fileio.write_chunk(self.tmp_path, DUMMY_BYTES, -42,
-                                         self.logger)
+        did_succeed = fileio.write_chunk(
+            self.tmp_path, DUMMY_BYTES, -42, self.logger
+        )
         self.assertFalse(did_succeed)

     def test_return_false_on_invalid_dir(self):
@@ -368,7 +374,7 @@ def test_store_bytes(self):
         """Test write_chunk stores byte data correctly at offset 0"""
         fileio.write_chunk(self.tmp_path, DUMMY_BYTES, 0, self.logger)

-        with open(self.tmp_path, 'rb') as file:
+        with open(self.tmp_path, "rb") as file:
             content = file.read(1024)
             self.assertEqual(len(content), DUMMY_BYTES_LENGTH)
             self.assertEqual(content[:], DUMMY_BYTES)
@@ -379,42 +385,52 @@ def test_store_bytes_at_offset(self):
         fileio.write_chunk(self.tmp_path, DUMMY_BYTES, offset, self.logger)

-        with open(self.tmp_path, 'rb') as file:
+        with open(self.tmp_path, "rb") as file:
             content = file.read(1024)
             self.assertEqual(len(content), DUMMY_BYTES_LENGTH + offset)
-            self.assertEqual(content[0:3], bytearray([0, 0, 0]),
-                             "expected a hole was left")
+            self.assertEqual(
+                content[0:3], bytearray([0, 0, 0]), "expected a hole was left"
+            )
             self.assertEqual(content[3:], DUMMY_BYTES)

-    @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select")
+    @unittest.skip(
+        "TODO: enable again - requires the temporarily disabled auto mode select"
+    )
     def test_store_bytes_in_text_mode(self):
         """Test write_chunk stores byte data in text mode"""
-        fileio.write_chunk(self.tmp_path, DUMMY_BYTES, 0, self.logger,
-                           mode="r+")
+        fileio.write_chunk(
+            self.tmp_path, DUMMY_BYTES, 0, self.logger,
mode="r+" + ) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode(self): """Test write_chunk stores unicode data in text mode""" - fileio.write_chunk(self.tmp_path, DUMMY_UNICODE, 0, self.logger, - mode='r+') + fileio.write_chunk( + self.tmp_path, DUMMY_UNICODE, 0, self.logger, mode="r+" + ) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode_in_binary_mode(self): """Test write_chunk stores unicode data in binary mode""" - fileio.write_chunk(self.tmp_path, DUMMY_UNICODE, 0, self.logger, - mode='r+b') + fileio.write_chunk( + self.tmp_path, DUMMY_UNICODE, 0, self.logger, mode="r+b" + ) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) @@ -425,12 +441,13 @@ class MigSharedFileio__write_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) # NOTE: we inject sub-directory to test with missing and existing self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) @@ -441,24 +458,25 @@ def test_return_false_on_invalid_data(self): self.logger.forgive_errors() # NOTE: we make sure to disable any forced stringification here - did_succeed = fileio.write_file(1234, self.tmp_path, self.logger, - force_string=False) + did_succeed = fileio.write_file( + 1234, self.tmp_path, self.logger, force_string=False + ) self.assertFalse(did_succeed) def test_return_false_on_invalid_dir(self): """Test write_file returns False when path is a directory""" self.logger.forgive_errors() ensure_dirs_exist(self.tmp_path) - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, - self.logger) + did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger) self.assertFalse(did_succeed) def test_return_false_on_missing_dir(self): """Test write_file returns False on missing parent dir""" self.logger.forgive_errors() - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, - self.logger, make_parent=False) + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, make_parent=False + ) self.assertFalse(did_succeed) def test_creates_directory(self): @@ -466,7 +484,7 @@ def test_creates_directory(self): # TODO: temporarily use empty string to avoid any byte/unicode issues # did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, # self.logger) - did_succeed = fileio.write_file('', self.tmp_path, self.logger) + did_succeed = 
fileio.write_file("", self.tmp_path, self.logger) self.assertTrue(did_succeed) path_kind = self.assertPathExists(self.tmp_path) @@ -475,49 +493,59 @@ def test_creates_directory(self): # TODO: replace next test once we have auto adjust mode in write helper def test_store_bytes_with_manual_adjust_mode(self): """Test write_file stores byte data in with manual adjust mode call""" - mode = 'w' + mode = "w" mode = fileio._auto_adjust_mode(DUMMY_BYTES, mode) - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger, - mode=mode) + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, mode=mode + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_bytes_in_text_mode(self): """Test write_file stores byte data when opening in text mode""" - did_succeed = fileio.write_file(DUMMY_BYTES, self.tmp_path, self.logger, - mode="w") + did_succeed = fileio.write_file( + DUMMY_BYTES, self.tmp_path, self.logger, mode="w" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'rb') as file: + with open(self.tmp_path, "rb") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_BYTES_LENGTH) self.assertEqual(content[:], DUMMY_BYTES) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode(self): """Test write_file stores unicode string when opening in text mode""" - did_succeed = fileio.write_file(DUMMY_UNICODE, self.tmp_path, - self.logger, mode='w') + did_succeed = fileio.write_file( + DUMMY_UNICODE, self.tmp_path, self.logger, mode="w" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) - @unittest.skip("TODO: enable again - requires the temporarily disabled auto mode select") + @unittest.skip( + "TODO: enable again - requires the temporarily disabled auto mode select" + ) def test_store_unicode_in_binary_mode(self): """Test write_file handles unicode strings when opening in binary mode""" - did_succeed = fileio.write_file(DUMMY_UNICODE, self.tmp_path, - self.logger, mode='wb') + did_succeed = fileio.write_file( + DUMMY_UNICODE, self.tmp_path, self.logger, mode="wb" + ) self.assertTrue(did_succeed) - with open(self.tmp_path, 'r') as file: + with open(self.tmp_path, "r") as file: content = file.read(1024) self.assertEqual(len(content), DUMMY_UNICODE_LENGTH) self.assertEqual(content[:], DUMMY_UNICODE) @@ -528,12 +556,13 @@ class MigSharedFileio__write_file_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) # NOTE: we inject sub-directory to test with 
missing and existing self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) @@ -542,8 +571,7 @@ def before_each(self): def test_write_lines(self): """Test write_file_lines writes lines to a file""" test_lines = ["line1\n", "line2\n", "line3"] - result = fileio.write_file_lines( - test_lines, self.tmp_path, self.logger) + result = fileio.write_file_lines(test_lines, self.tmp_path, self.logger) self.assertTrue(result) # Verify with read_file_lines @@ -559,8 +587,7 @@ def test_invalid_data(self): def test_creates_directory(self): """Test write_file_lines creates parent directory when needed""" test_lines = ["test line"] - result = fileio.write_file_lines( - test_lines, self.tmp_path, self.logger) + result = fileio.write_file_lines(test_lines, self.tmp_path, self.logger) self.assertTrue(result) path_kind = self.assertPathExists(self.tmp_path) @@ -571,14 +598,16 @@ def test_return_false_on_invalid_dir(self): self.logger.forgive_errors() ensure_dirs_exist(self.tmp_path) result = fileio.write_file_lines( - [DUMMY_TEXT], self.tmp_path, self.logger) + [DUMMY_TEXT], self.tmp_path, self.logger + ) self.assertFalse(result) def test_return_false_on_missing_dir(self): """Test write_file_lines fails when parent directory missing""" self.logger.forgive_errors() - result = fileio.write_file_lines([DUMMY_TEXT], self.tmp_path, self.logger, - make_parent=False) + result = fileio.write_file_lines( + [DUMMY_TEXT], self.tmp_path, self.logger, make_parent=False + ) self.assertFalse(result) @@ -587,40 +616,43 @@ class MigSharedFileio__read_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READFILE) def test_reads_bytes(self): """Test read_file returns byte content with binary mode""" - with open(self.tmp_path, 'wb') as fh: + with open(self.tmp_path, "wb") as fh: fh.write(DUMMY_BYTES) - content = fileio.read_file(self.tmp_path, self.logger, mode='rb') + content = fileio.read_file(self.tmp_path, self.logger, mode="rb") self.assertEqual(content, DUMMY_BYTES) def test_reads_text(self): """Test read_file returns text with text mode""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write(DUMMY_UNICODE) - content = fileio.read_file(self.tmp_path, self.logger, mode='r') + content = fileio.read_file(self.tmp_path, self.logger, mode="r") self.assertEqual(content, DUMMY_UNICODE) def test_allows_missing_file(self): """Test read_file returns None with allow_missing=True""" content = fileio.read_file( - 'missing.txt', self.logger, allow_missing=True) + "missing.txt", self.logger, allow_missing=True + ) self.assertIsNone(content) def test_reports_missing_file(self): """Test read_file returns None with allow_missing=False""" self.logger.forgive_errors() content = fileio.read_file( - 'missing.txt', self.logger, allow_missing=False) + "missing.txt", self.logger, allow_missing=False + ) self.assertIsNone(content) def test_handles_directory_path(self): @@ -636,31 +668,32 @@ class MigSharedFileio__read_file_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return 
"testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READFILELINES) def test_returns_empty_list_for_empty_file(self): """Test read_file_lines returns empty list for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_file_lines(self.tmp_path, self.logger) self.assertEqual(lines, []) def test_reads_lines_from_file(self): """Test read_file_lines returns lines from text file""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3") lines = fileio.read_file_lines(self.tmp_path, self.logger) self.assertEqual(lines, ["line1\n", "line2\n", "line3"]) def test_none_for_missing_file(self): self.logger.forgive_errors() - lines = fileio.read_file_lines('missing.txt', self.logger) + lines = fileio.read_file_lines("missing.txt", self.logger) self.assertIsNone(lines) @@ -669,18 +702,19 @@ class MigSharedFileio__get_file_size(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_GETFILESIZE) def test_returns_file_size(self): """Test get_file_size returns correct file size""" - with open(self.tmp_path, 'wb') as fh: + with open(self.tmp_path, "wb") as fh: fh.write(DUMMY_BYTES) size = fileio.get_file_size(self.tmp_path, self.logger) self.assertEqual(size, DUMMY_BYTES_LENGTH) @@ -688,7 +722,7 @@ def test_returns_file_size(self): def test_handles_missing_file(self): """Test get_file_size returns -1 for missing file""" self.logger.forgive_errors() - size = fileio.get_file_size('missing.txt', self.logger) + size = fileio.get_file_size("missing.txt", self.logger) self.assertEqual(size, -1) def test_handles_directory(self): @@ -707,18 +741,19 @@ class MigSharedFileio__delete_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_DELETEFILE) def test_deletes_existing_file(self): """Test delete_file removes existing file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() result = fileio.delete_file(self.tmp_path, self.logger) self.assertTrue(result) self.assertFalse(os.path.exists(self.tmp_path)) @@ -726,15 +761,16 @@ def test_deletes_existing_file(self): def test_handles_missing_file_with_allow_missing(self): """Test delete_file succeeds with allow_missing=True""" result = fileio.delete_file( - 'missing.txt', self.logger, allow_missing=True) + "missing.txt", self.logger, allow_missing=True + ) self.assertTrue(result) def 
test_false_for_missing_file_without_allow_missing(self): """Test delete_file returns False with allow_missing=False""" self.logger.forgive_errors() - result = fileio.delete_file('missing.txt', - self.logger, - allow_missing=False) + result = fileio.delete_file( + "missing.txt", self.logger, allow_missing=False + ) self.assertFalse(result) @@ -743,39 +779,40 @@ class MigSharedFileio__read_head_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READHEADLINES) def test_reads_requested_lines(self): """Test read_head_lines returns requested number of lines""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3\nline4") lines = fileio.read_head_lines(self.tmp_path, 2, self.logger) self.assertEqual(lines, ["line1\n", "line2\n"]) def test_returns_all_lines_when_requested_more(self): """Test read_head_lines returns all lines when file has fewer""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2") lines = fileio.read_head_lines(self.tmp_path, 5, self.logger) self.assertEqual(lines, ["line1\n", "line2"]) def test_returns_empty_list_for_empty_file(self): """Test read_head_lines returns empty for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_head_lines(self.tmp_path, 3, self.logger) self.assertEqual(lines, []) def test_empty_for_missing_file(self): """Test read_head_lines returns [] for missing file""" self.logger.forgive_errors() - lines = fileio.read_head_lines('missing.txt', 3, self.logger) + lines = fileio.read_head_lines("missing.txt", 3, self.logger) self.assertEqual(lines, []) @@ -784,39 +821,40 @@ class MigSharedFileio__read_tail_lines(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_READTAILLINES) def test_reads_requested_lines(self): """Test read_tail_lines returns requested number of lines""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2\nline3\nline4") lines = fileio.read_tail_lines(self.tmp_path, 2, self.logger) self.assertEqual(lines, ["line3\n", "line4"]) def test_returns_all_lines_when_requested_more(self): """Test read_tail_lines returns all lines when file has fewer""" - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write("line1\nline2") lines = fileio.read_tail_lines(self.tmp_path, 5, self.logger) self.assertEqual(lines, ["line1\n", "line2"]) def test_returns_empty_list_for_empty_file(self): """Test read_tail_lines returns empty for empty file""" - open(self.tmp_path, 'w').close() + open(self.tmp_path, "w").close() lines = fileio.read_tail_lines(self.tmp_path, 
3, self.logger) self.assertEqual(lines, []) def test_empty_for_missing_file(self): """Test read_tail_lines returns [] for missing file""" self.logger.forgive_errors() - lines = fileio.read_tail_lines('missing.txt', 3, self.logger) + lines = fileio.read_tail_lines("missing.txt", 3, self.logger) self.assertEqual(lines, []) @@ -825,49 +863,52 @@ class MigSharedFileio__make_symlink(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) ensure_dirs_exist(self.tmp_dir) self.tmp_link = os.path.join(self.tmp_dir, DUMMY_FILE_MAKESYMLINKSRC) - self.tmp_target = os.path.join(self.tmp_dir, - DUMMY_FILE_MAKESYMLINKDST) - with open(self.tmp_target, 'w') as fh: + self.tmp_target = os.path.join(self.tmp_dir, DUMMY_FILE_MAKESYMLINKDST) + with open(self.tmp_target, "w") as fh: fh.write(DUMMY_TEXT) def test_creates_symlink(self): """Test make_symlink creates working symlink""" result = fileio.make_symlink( - self.tmp_target, self.tmp_link, self.logger) + self.tmp_target, self.tmp_link, self.logger + ) self.assertTrue(result) self.assertTrue(os.path.islink(self.tmp_link)) self.assertEqual(os.readlink(self.tmp_link), self.tmp_target) def test_force_overwrites_existing_link(self): """Test make_symlink force replaces existing link""" - os.symlink('/dummy', self.tmp_link) - result = fileio.make_symlink(self.tmp_target, self.tmp_link, - self.logger, force=True) + os.symlink("/dummy", self.tmp_link) + result = fileio.make_symlink( + self.tmp_target, self.tmp_link, self.logger, force=True + ) self.assertTrue(result) self.assertEqual(os.readlink(self.tmp_link), self.tmp_target) def test_fails_on_existing_link_without_force(self): """Test make_symlink fails on existing link without force""" self.logger.forgive_errors() - os.symlink('/dummy', self.tmp_link) - result = fileio.make_symlink(self.tmp_target, self.tmp_link, self.logger, - force=False) + os.symlink("/dummy", self.tmp_link) + result = fileio.make_symlink( + self.tmp_target, self.tmp_link, self.logger, force=False + ) self.assertFalse(result) def test_handles_nonexistent_target(self): """Test make_symlink still creates broken symlink""" self.logger.forgive_errors() - broken_target = self.tmp_target + '-nonexistent' + broken_target = self.tmp_target + "-nonexistent" result = fileio.make_symlink(broken_target, self.tmp_link, self.logger) self.assertTrue(result) self.assertTrue(os.path.islink(self.tmp_link)) @@ -879,20 +920,21 @@ class MigSharedFileio__delete_symlink(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_dir = os.path.join(self.tmp_base, DUMMY_SUBDIR) ensure_dirs_exist(self.tmp_dir) - self.tmp_link = os.path.join(self.tmp_dir, - DUMMY_FILE_DELETESYMLINKSRC) - self.tmp_target = os.path.join(self.tmp_dir, - 
DUMMY_FILE_DELETESYMLINKDST) - with open(self.tmp_target, 'w') as fh: + self.tmp_link = os.path.join(self.tmp_dir, DUMMY_FILE_DELETESYMLINKSRC) + self.tmp_target = os.path.join( + self.tmp_dir, DUMMY_FILE_DELETESYMLINKDST + ) + with open(self.tmp_target, "w") as fh: fh.write(DUMMY_TEXT) def create_symlink(self, target=None, link=None): @@ -915,33 +957,36 @@ def test_handles_missing_file_with_allow_missing(self): # First make sure file doesn't exist if os.path.exists(self.tmp_link): os.remove(self.tmp_link) - result = fileio.delete_symlink(self.tmp_link, self.logger, - allow_missing=True) + result = fileio.delete_symlink( + self.tmp_link, self.logger, allow_missing=True + ) self.assertTrue(result) def test_handles_missing_symlink_without_allow_missing(self): """Test delete_symlink fails with allow_missing=False""" self.logger.forgive_errors() - result = fileio.delete_symlink('missing_symlink', self.logger, - allow_missing=False) + result = fileio.delete_symlink( + "missing_symlink", self.logger, allow_missing=False + ) self.assertFalse(result) @unittest.skip("TODO: implement check in tested function and enable again") def test_rejects_regular_file(self): """Test delete_symlink returns False when path is a regular file""" - with open(self.tmp_link, 'w') as fh: + with open(self.tmp_link, "w") as fh: fh.write(DUMMY_TEXT) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: result = fileio.delete_symlink(self.tmp_link, self.logger) self.assertFalse(result) - self.assertTrue(any('Could not remove' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Could not remove" in msg for msg in log_capture.output) + ) def test_deletes_broken_symlink(self): """Test delete_symlink removes broken symlink""" # Create broken symlink - broken_target = self.tmp_target + '-nonexistent' + broken_target = self.tmp_target + "-nonexistent" self.create_symlink(broken_target) self.assertTrue(os.path.islink(self.tmp_link)) # Now delete it @@ -954,12 +999,13 @@ class MigSharedFileio__touch(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_FILE_TOUCH) @@ -971,11 +1017,13 @@ def test_creates_new_file(self): self.assertTrue(os.path.exists(self.tmp_path)) self.assertTrue(os.path.isfile(self.tmp_path)) - @unittest.skip("TODO: fix invalid open 'r+w' in tested function and enable again") + @unittest.skip( + "TODO: fix invalid open 'r+w' in tested function and enable again" + ) def test_updates_timestamp_on_existing_file(self): """Test touch updates timestamp on existing file""" # Create initial file - with open(self.tmp_path, 'w') as fh: + with open(self.tmp_path, "w") as fh: fh.write(DUMMY_TEXT) orig_mtime = os.path.getmtime(self.tmp_path) time.sleep(0.1) @@ -984,7 +1032,9 @@ def test_updates_timestamp_on_existing_file(self): new_mtime = os.path.getmtime(self.tmp_path) self.assertNotEqual(orig_mtime, new_mtime) - @unittest.skip("TODO: fix handling of directory in tested function and enable again") + @unittest.skip( + "TODO: fix handling of directory in tested function and enable again" + ) def test_succeeds_on_directory(self): 
"""Test touch succeeds for existing directory and updates timestamp""" ensure_dirs_exist(self.tmp_path) @@ -999,7 +1049,7 @@ def test_succeeds_on_directory(self): def test_fails_on_missing_parent(self): """Test touch fails when parent directory doesn't exist""" self.logger.forgive_errors() - nested_path = os.path.join(self.tmp_path, 'missing', DUMMY_FILE_ONE) + nested_path = os.path.join(self.tmp_path, "missing", DUMMY_FILE_ONE) result = fileio.touch(nested_path, self.configuration) self.assertFalse(result) self.assertFalse(os.path.exists(nested_path)) @@ -1010,12 +1060,13 @@ class MigSharedFileio__remove_dir(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVE) # NOTE: we prepare tmp_path as directory here @@ -1032,7 +1083,7 @@ def test_fails_on_nonempty_directory(self): """Test remove_dir returns False for non-empty directory""" self.logger.forgive_errors() # Add a file to the directory - with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) result = fileio.remove_dir(self.tmp_path, self.configuration) self.assertFalse(result) @@ -1043,7 +1094,7 @@ def test_fails_on_file(self): self.logger.forgive_errors() # Add a file to the directory file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with open(file_path, 'w') as fh: + with open(file_path, "w") as fh: fh.write(DUMMY_TEXT) result = fileio.remove_dir(file_path, self.configuration) self.assertFalse(result) @@ -1055,12 +1106,13 @@ class MigSharedFileio__remove_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVEREC) # Create a nested directory structure with files @@ -1069,10 +1121,11 @@ def before_each(self): # └── subdir/ # └── file2.txt ensure_dirs_exist(os.path.join(self.tmp_path, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_path, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_path, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_removes_directory_recursively(self): @@ -1085,7 +1138,7 @@ def test_removes_directory_recursively(self): def test_removes_directory_recursively_with_symlink(self): """Test remove_rec removes directory and contents with symlink""" link_src = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - link_dst = os.path.join(self.tmp_path, DUMMY_FILE_ONE + '.lnk') + link_dst = os.path.join(self.tmp_path, DUMMY_FILE_ONE + ".lnk") os.symlink(link_src, link_dst) self.assertTrue(os.path.exists(self.tmp_path)) 
result = fileio.remove_rec(self.tmp_path, self.configuration) @@ -1095,7 +1148,7 @@ def test_removes_directory_recursively_with_symlink(self): def test_removes_directory_recursively_with_broken_symlink(self): """Test remove_rec removes directory and contents with broken symlink""" link_src = os.path.join(self.tmp_path, DUMMY_FILE_MISSING) - link_dst = os.path.join(self.tmp_path, DUMMY_FILE_MISSING + '.lnk') + link_dst = os.path.join(self.tmp_path, DUMMY_FILE_MISSING + ".lnk") os.symlink(link_src, link_dst) self.assertTrue(os.path.exists(self.tmp_path)) result = fileio.remove_rec(self.tmp_path, self.configuration) @@ -1115,11 +1168,12 @@ def test_removes_directory_recursively_despite_readonly(self): def test_rejects_regular_file(self): """Test remove_rec returns False when path is a regular file""" file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with self.assertLogs(level='ERROR') as log_capture: + with self.assertLogs(level="ERROR") as log_capture: result = fileio.remove_rec(file_path, self.configuration) self.assertFalse(result) - self.assertTrue(any('Could not remove' in msg for msg in - log_capture.output)) + self.assertTrue( + any("Could not remove" in msg for msg in log_capture.output) + ) self.assertTrue(os.path.exists(file_path)) @@ -1128,22 +1182,24 @@ class MigSharedFileio__move_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_FILE_MOVE_SRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_FILE_MOVE_DST) - with open(self.tmp_src, 'w') as fh: + with open(self.tmp_src, "w") as fh: fh.write(DUMMY_TEXT) def test_moves_file(self): """Test move_file successfully moves a file""" - success, msg = fileio.move_file(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_file( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) self.assertFalse(os.path.exists(self.tmp_src)) @@ -1152,13 +1208,14 @@ def test_moves_file(self): def test_overwrites_existing_destination(self): """Test move_file overwrites existing destination file""" # Create initial destination file - with open(self.tmp_dst, 'w') as fh: + with open(self.tmp_dst, "w") as fh: fh.write(DUMMY_TWICE) - success, msg = fileio.move_file(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_file( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) - with open(self.tmp_dst, 'r') as fh: + with open(self.tmp_dst, "r") as fh: content = fh.read() self.assertEqual(content, DUMMY_TEXT) @@ -1168,12 +1225,13 @@ class MigSharedFileio__move_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_REMOVE) self.tmp_src = 
os.path.join(self.tmp_base, DUMMY_DIRECTORY_MOVE_SRC) @@ -1184,43 +1242,54 @@ def before_each(self): # └── subdir/ # └── file2.txt ensure_dirs_exist(os.path.join(self.tmp_src, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_src, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_src, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_moves_directory_recursively(self): """Test move_rec moves directory and contents""" - result = fileio.move_rec(self.tmp_src, self.tmp_dst, - self.configuration) + result = fileio.move_rec(self.tmp_src, self.tmp_dst, self.configuration) self.assertTrue(result) self.assertFalse(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) # Verify structure - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, - DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, DUMMY_SUBDIR, - DUMMY_FILE_TWO))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO) + ) + ) def test_extends_existing_destination(self): """Test move_rec extends existing destination directory""" # Create initial destination with some content ensure_dirs_exist(os.path.join(self.tmp_dst, DUMMY_TESTDIR)) - success, msg = fileio.move_rec(self.tmp_src, self.tmp_dst, - self.configuration) + success, msg = fileio.move_rec( + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(success) self.assertFalse(msg) # Verify structure with new src subdir and existing dir new_sub = os.path.basename(DUMMY_DIRECTORY_MOVE_SRC) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, new_sub, - DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join(self.tmp_dst, new_sub, - DUMMY_SUBDIR, - DUMMY_FILE_TWO))) - self.assertTrue(os.path.exists( - os.path.join(self.tmp_dst, DUMMY_TESTDIR))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, new_sub, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join( + self.tmp_dst, new_sub, DUMMY_SUBDIR, DUMMY_FILE_TWO + ) + ) + ) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_TESTDIR)) + ) class MigSharedFileio__copy_file(MigTestCase): @@ -1228,23 +1297,25 @@ class MigSharedFileio__copy_file(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_FILE_COPY_SRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_FILE_COPY_DST) - with open(self.tmp_src, 'w') as fh: + with open(self.tmp_src, "w") as fh: fh.write(DUMMY_TEXT) def test_copies_file(self): """Test copy_file successfully copies a file""" result = fileio.copy_file( - self.tmp_src, self.tmp_dst, self.configuration) + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(result) self.assertTrue(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) @@ -1252,12 +1323,13 @@ def test_copies_file(self): def 
test_overwrites_existing_destination(self): """Test copy_file overwrites existing destination file""" # Create initial destination file - with open(self.tmp_dst, 'w') as fh: + with open(self.tmp_dst, "w") as fh: fh.write(DUMMY_TWICE) result = fileio.copy_file( - self.tmp_src, self.tmp_dst, self.configuration) + self.tmp_src, self.tmp_dst, self.configuration + ) self.assertTrue(result) - with open(self.tmp_dst, 'r') as fh: + with open(self.tmp_dst, "r") as fh: content = fh.read() self.assertEqual(content, DUMMY_TEXT) @@ -1267,36 +1339,41 @@ class MigSharedFileio__copy_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.tmp_src = os.path.join(self.tmp_base, DUMMY_DIRECTORY_COPYRECSRC) self.tmp_dst = os.path.join(self.tmp_base, DUMMY_DIRECTORY_COPYRECDST) # Create a nested directory structure with files ensure_dirs_exist(self.tmp_src) ensure_dirs_exist(os.path.join(self.tmp_src, DUMMY_SUBDIR)) - with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.tmp_src, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) - with open(os.path.join(self.tmp_src, DUMMY_SUBDIR, - DUMMY_FILE_TWO), 'w') as fh: + with open( + os.path.join(self.tmp_src, DUMMY_SUBDIR, DUMMY_FILE_TWO), "w" + ) as fh: fh.write(DUMMY_TWICE) def test_copies_directory_recursively(self): """Test copy_rec copies directory and contents""" - result = fileio.copy_rec( - self.tmp_src, self.tmp_dst, self.configuration) + result = fileio.copy_rec(self.tmp_src, self.tmp_dst, self.configuration) self.assertTrue(result) self.assertTrue(os.path.exists(self.tmp_src)) self.assertTrue(os.path.exists(self.tmp_dst)) # Verify structure - self.assertTrue(os.path.exists(os.path.join( - self.tmp_dst, DUMMY_FILE_ONE))) - self.assertTrue(os.path.exists(os.path.join( - self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO))) + self.assertTrue( + os.path.exists(os.path.join(self.tmp_dst, DUMMY_FILE_ONE)) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.tmp_dst, DUMMY_SUBDIR, DUMMY_FILE_TWO) + ) + ) class MigSharedFileio__check_empty_dir(MigTestCase): @@ -1304,20 +1381,20 @@ class MigSharedFileio__check_empty_dir(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) self.empty_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_EMPTY) - self.nonempty_path = os.path.join( - self.tmp_base, DUMMY_DIRECTORY_NESTED) + self.nonempty_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_NESTED) ensure_dirs_exist(self.empty_path) # Create non-empty directory structure ensure_dirs_exist(self.nonempty_path) - with open(os.path.join(self.nonempty_path, DUMMY_FILE_ONE), 'w') as fh: + with open(os.path.join(self.nonempty_path, DUMMY_FILE_ONE), "w") as fh: fh.write(DUMMY_TEXT) def test_returns_true_for_empty(self): @@ -1340,20 +1417,21 @@ class 
MigSharedFileio__makedirs_rec(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) - self.tmp_path = os.path.join(self.tmp_base, - DUMMY_DIRECTORY_MAKEDIRSREC) + self.tmp_path = os.path.join(self.tmp_base, DUMMY_DIRECTORY_MAKEDIRSREC) def test_creates_directory_path(self): """Test makedirs_rec creates nested directories""" - nested_path = os.path.join(self.tmp_path, DUMMY_TESTDIR, DUMMY_SUBDIR, - DUMMY_TESTDIR) + nested_path = os.path.join( + self.tmp_path, DUMMY_TESTDIR, DUMMY_SUBDIR, DUMMY_TESTDIR + ) result = fileio.makedirs_rec(nested_path, self.configuration) self.assertTrue(result) self.assertTrue(os.path.exists(nested_path)) @@ -1370,7 +1448,7 @@ def test_fails_for_file_path(self): # Create a file at the path ensure_dirs_exist(self.tmp_path) file_path = os.path.join(self.tmp_path, DUMMY_FILE_ONE) - with open(file_path, 'w') as fh: + with open(file_path, "w") as fh: fh.write(DUMMY_TEXT) result = fileio.makedirs_rec(file_path, self.configuration) self.assertFalse(result) @@ -1381,26 +1459,26 @@ class MigSharedFileio__check_access(MigTestCase): def _provide_configuration(self): """Set up isolated test configuration and logger for the tests""" - return 'testconfig' + return "testconfig" def before_each(self): """Setup test environment before each test method""" - self.tmp_base = os.path.join(self.configuration.mig_system_run, - DUMMY_TESTDIR) + self.tmp_base = os.path.join( + self.configuration.mig_system_run, DUMMY_TESTDIR + ) ensure_dirs_exist(self.tmp_base) - self.tmp_dir = os.path.join(self.tmp_base, - DUMMY_DIRECTORY_CHECKACCESS) + self.tmp_dir = os.path.join(self.tmp_base, DUMMY_DIRECTORY_CHECKACCESS) ensure_dirs_exist(self.tmp_dir) self.writeonly_file = os.path.join(self.tmp_dir, DUMMY_FILE_WO) self.readonly_file = os.path.join(self.tmp_dir, DUMMY_FILE_RO) self.readwrite_file = os.path.join(self.tmp_dir, DUMMY_FILE_RW) # Create test files with different permissions - with open(self.writeonly_file, 'w') as fh: + with open(self.writeonly_file, "w") as fh: fh.write(DUMMY_TEXT) - with open(self.readonly_file, 'w') as fh: + with open(self.readonly_file, "w") as fh: fh.write(DUMMY_TEXT) - with open(self.readwrite_file, 'w') as fh: + with open(self.readwrite_file, "w") as fh: fh.write(DUMMY_TEXT) # Set permissions @@ -1414,14 +1492,13 @@ def test_check_read_access_file(self): """Test check_read_access with readable file""" self.assertTrue(fileio.check_read_access(self.readwrite_file)) self.assertTrue(fileio.check_read_access(self.readonly_file)) - self.assertTrue(fileio.check_read_access(self.tmp_dir, - parent_dir=True)) + self.assertTrue(fileio.check_read_access(self.tmp_dir, parent_dir=True)) # Super-user has access to read and write all files! 
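         # (uid 0 passes kernel permission checks regardless of mode bits,
         # hence the split expectations below)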
if os.getuid() == 0: self.assertTrue(fileio.check_read_access(self.writeonly_file)) else: self.assertFalse(fileio.check_read_access(self.writeonly_file)) - self.assertFalse(fileio.check_read_access('/invalid/path')) + self.assertFalse(fileio.check_read_access("/invalid/path")) def test_check_write_access_file(self): """Test check_write_access with writable file""" @@ -1432,7 +1509,7 @@ def test_check_write_access_file(self): self.assertTrue(fileio.check_write_access(self.readonly_file)) else: self.assertFalse(fileio.check_write_access(self.readonly_file)) - self.assertFalse(fileio.check_write_access('/invalid/path')) + self.assertFalse(fileio.check_write_access("/invalid/path")) def test_check_read_access_with_parent(self): """Test check_read_access with parent_dir True""" @@ -1448,78 +1525,108 @@ def test_check_write_access_with_parent(self): def test_check_readable(self): """Test check_readable wrapper function""" - self.assertTrue(fileio.check_readable(self.configuration, - self.readwrite_file)) - self.assertTrue(fileio.check_readable(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_readable(self.configuration, self.readwrite_file) + ) + self.assertTrue( + fileio.check_readable(self.configuration, self.readonly_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_readable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_readable(self.configuration, self.writeonly_file) + ) else: - self.assertFalse(fileio.check_readable(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readable(self.configuration, - '/invalid/path')) + self.assertFalse( + fileio.check_readable(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readable(self.configuration, "/invalid/path") + ) def test_check_writable(self): """Test check_writable wrapper function""" - self.assertTrue(fileio.check_writable(self.configuration, - self.readwrite_file)) - self.assertTrue(fileio.check_writable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_writable(self.configuration, self.readwrite_file) + ) + self.assertTrue( + fileio.check_writable(self.configuration, self.writeonly_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_writable(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_writable(self.configuration, self.readonly_file) + ) else: - self.assertFalse(fileio.check_writable(self.configuration, - self.readonly_file)) - self.assertFalse(fileio.check_writable(self.configuration, - "/no/such/file")) + self.assertFalse( + fileio.check_writable(self.configuration, self.readonly_file) + ) + self.assertFalse( + fileio.check_writable(self.configuration, "/no/such/file") + ) def test_check_readonly(self): """Test check_readonly wrapper function""" # Super-user has access to read and write all files! 
if os.getuid() == 0: # Test with read-only file path - self.assertFalse(fileio.check_readonly(self.configuration, - self.readonly_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readonly_file) + ) # Test with writable file - self.assertFalse(fileio.check_readonly(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readonly(self.configuration, - self.readwrite_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readwrite_file) + ) else: # Test with read-only file path - self.assertTrue(fileio.check_readonly(self.configuration, - self.readonly_file)) + self.assertTrue( + fileio.check_readonly(self.configuration, self.readonly_file) + ) # Test with writable file - self.assertFalse(fileio.check_readonly(self.configuration, - self.writeonly_file)) - self.assertFalse(fileio.check_readonly(self.configuration, - self.readwrite_file)) + self.assertFalse( + fileio.check_readonly(self.configuration, self.writeonly_file) + ) + self.assertFalse( + fileio.check_readonly(self.configuration, self.readwrite_file) + ) def test_check_readwritable(self): """Test check_readwritable wrapper function""" - self.assertTrue(fileio.check_readwritable(self.configuration, - self.readwrite_file)) + self.assertTrue( + fileio.check_readwritable(self.configuration, self.readwrite_file) + ) # Super-user has access to read and write all files! if os.getuid() == 0: - self.assertTrue(fileio.check_readwritable(self.configuration, - self.readonly_file)) - self.assertTrue(fileio.check_readwritable(self.configuration, - self.writeonly_file)) + self.assertTrue( + fileio.check_readwritable( + self.configuration, self.readonly_file + ) + ) + self.assertTrue( + fileio.check_readwritable( + self.configuration, self.writeonly_file + ) + ) else: - self.assertFalse(fileio.check_readwritable(self.configuration, - self.readonly_file)) - self.assertFalse(fileio.check_readwritable(self.configuration, - self.writeonly_file)) - - self.assertFalse(fileio.check_readwritable(self.configuration, - "/invalid/file")) + self.assertFalse( + fileio.check_readwritable( + self.configuration, self.readonly_file + ) + ) + self.assertFalse( + fileio.check_readwritable( + self.configuration, self.writeonly_file + ) + ) + + self.assertFalse( + fileio.check_readwritable(self.configuration, "/invalid/file") + ) def test_special_cases(self): """Test various special cases for access checks""" @@ -1533,11 +1640,13 @@ def test_special_cases(self): self.assertFalse(fileio.check_write_access(missing_path)) # Check with custom follow_symlink=False - self.assertTrue(fileio.check_read_access(self.readwrite_file, - follow_symlink=False)) - self.assertTrue(fileio.check_read_access(self.tmp_dir, True, - follow_symlink=False)) + self.assertTrue( + fileio.check_read_access(self.readwrite_file, follow_symlink=False) + ) + self.assertTrue( + fileio.check_read_access(self.tmp_dir, True, follow_symlink=False) + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_filemarks.py b/tests/test_mig_shared_filemarks.py index a640deba9..41a3a95c8 100644 --- a/tests/test_mig_shared_filemarks.py +++ b/tests/test_mig_shared_filemarks.py @@ -36,11 +36,12 @@ # Imports of the code under test from mig.shared.filemarks import get_filemark, reset_filemark, update_filemark + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, 
     testmain

-TEST_MARKS_DIR = 'TestMarks'
-TEST_MARKS_FILE = 'file.mark'
+TEST_MARKS_DIR = "TestMarks"
+TEST_MARKS_FILE = "file.mark"


 class TestMigSharedFilemarks(MigTestCase):
@@ -48,7 +49,7 @@ class TestMigSharedFilemarks(MigTestCase):

     def _provide_configuration(self):
         """Set up isolated test configuration and logger for the tests"""
-        return 'testconfig'
+        return "testconfig"

     def _prepare_mark_for_test(self, mark_name=None, timestamp=None):
         """Prepare test for mark_name with timestamp in default location"""
@@ -57,7 +58,7 @@ def _prepare_mark_for_test(self, mark_name=None, timestamp=None):
         if timestamp is None:
             timestamp = time.time()
         self.marks_path = os.path.join(self.marks_base, mark_name)
-        open(self.marks_path, 'w').close()
+        open(self.marks_path, "w").close()
         os.utime(self.marks_path, (timestamp, timestamp))
         return timestamp
@@ -69,8 +70,9 @@ def _verify_mark_after_test(self, mark_name, timestamp):

     def before_each(self):
         """Setup fake configuration and temp dir before each test."""
-        self.marks_base = os.path.join(self.configuration.mig_system_run,
-                                       TEST_MARKS_DIR)
+        self.marks_base = os.path.join(
+            self.configuration.mig_system_run, TEST_MARKS_DIR
+        )
         ensure_dirs_exist(self.marks_base)
         self.marks_path = os.path.join(self.marks_base, TEST_MARKS_FILE)
@@ -78,8 +80,9 @@ def test_update_filemark_create(self):
         """Test update_filemark creates mark file with timestamp"""
         timestamp = 4242
         self.assertFalse(os.path.isfile(self.marks_path))
-        update_result = update_filemark(self.configuration, self.marks_base,
-                                        TEST_MARKS_FILE, timestamp)
+        update_result = update_filemark(
+            self.configuration, self.marks_base, TEST_MARKS_FILE, timestamp
+        )
         self.assertTrue(update_result)
         self.assertTrue(os.path.isfile(self.marks_path))
         self.assertEqual(os.path.getmtime(self.marks_path), timestamp)
@@ -89,8 +92,9 @@ def test_update_filemark_timestamp(self):
         timestamp = 424242
         self._prepare_mark_for_test(TEST_MARKS_FILE, 4242)

-        update_filemark(self.configuration, self.marks_base,
-                        TEST_MARKS_FILE, timestamp)
+        update_filemark(
+            self.configuration, self.marks_base, TEST_MARKS_FILE, timestamp
+        )
         self.assertTrue(os.path.isfile(self.marks_path))
         self.assertEqual(os.path.getmtime(self.marks_path), timestamp)
@@ -98,8 +102,9 @@ def test_update_filemark_delete(self):
         """Test update_filemark deletes mark files with negative timestamp"""
         self._prepare_mark_for_test(TEST_MARKS_FILE)

-        delete_result = update_filemark(self.configuration, self.marks_base,
-                                        TEST_MARKS_FILE, -1)
+        delete_result = update_filemark(
+            self.configuration, self.marks_base, TEST_MARKS_FILE, -1
+        )
         self.assertTrue(delete_result)
         self.assertFalse(os.path.exists(self.marks_path))
@@ -108,23 +113,26 @@ def test_get_filemark_existing(self):
         timestamp = 4242
         self._prepare_mark_for_test(TEST_MARKS_FILE, timestamp)

-        retrieved = get_filemark(self.configuration, self.marks_base,
-                                 TEST_MARKS_FILE)
+        retrieved = get_filemark(
+            self.configuration, self.marks_base, TEST_MARKS_FILE
+        )
         self.assertEqual(retrieved, timestamp)

     def test_get_filemark_missing(self):
         """Test get_filemark returns None for missing mark files"""
         self.assertFalse(os.path.isfile(self.marks_path))

-        retrieved = get_filemark(self.configuration, self.marks_base,
-                                 'missing.mark')
+        retrieved = get_filemark(
+            self.configuration, self.marks_base, "missing.mark"
+        )
         self.assertIsNone(retrieved)

     def test_reset_filemark_single(self):
         """Test reset_filemark updates single mark timestamp to 0"""
         self._prepare_mark_for_test(TEST_MARKS_FILE)

-        reset_result = reset_filemark(self.configuration, self.marks_base,
-                                      [TEST_MARKS_FILE])
+        reset_result = reset_filemark(
+            self.configuration, self.marks_base, [TEST_MARKS_FILE]
+        )
         self.assertTrue(reset_result)
         self._verify_mark_after_test(TEST_MARKS_FILE, 0)
@@ -133,18 +141,20 @@ def test_reset_filemark_delete(self):
         """Test reset_filemark deletes marks with delete=True"""
         self._prepare_mark_for_test(TEST_MARKS_FILE)

-        reset_result = reset_filemark(self.configuration, self.marks_base,
-                                      [TEST_MARKS_FILE], delete=True)
+        reset_result = reset_filemark(
+            self.configuration, self.marks_base, [TEST_MARKS_FILE], delete=True
+        )
         self.assertTrue(reset_result)

-        retrieved = get_filemark(self.configuration, self.marks_base,
-                                 TEST_MARKS_FILE)
+        retrieved = get_filemark(
+            self.configuration, self.marks_base, TEST_MARKS_FILE
+        )
         self.assertIsNone(retrieved)
         self.assertFalse(os.path.exists(self.marks_path))

     def test_reset_filemark_all(self):
         """Test reset_filemark resets all marks when mark_list=None"""
-        marks = ['mark1', 'mark2', 'mark3']
+        marks = ["mark1", "mark2", "mark3"]
         for mark in marks:
             self._prepare_mark_for_test(mark)
@@ -157,15 +167,22 @@ def test_reset_filemark_all(self):

     def test_update_filemark_fails_when_file_prevents_directory(self):
         """Test update_filemark fails when file prevents create directory"""
         # Create a file in the way to prevent subdir creation
-        self._prepare_mark_for_test('obstruct')
-
-        with self.assertLogs(level='ERROR') as log_capture:
-            result = update_filemark(self.configuration, self.marks_base,
-                                     os.path.join('obstruct', 'test.mark'),
-                                     time.time())
+        self._prepare_mark_for_test("obstruct")
+
+        with self.assertLogs(level="ERROR") as log_capture:
+            result = update_filemark(
+                self.configuration,
+                self.marks_base,
+                os.path.join("obstruct", "test.mark"),
+                time.time(),
+            )
         self.assertFalse(result)
-        self.assertTrue(any('in the way' in msg or 'could not create' in msg
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any(
+                "in the way" in msg or "could not create" in msg
+                for msg in log_capture.output
+            )
+        )

     @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user")
     def test_update_filemark_directory_perms_failure(self):
@@ -173,13 +190,17 @@
         # Create a read-only parent directory to prevent subdir creation
         os.chmod(self.marks_base, stat.S_IRUSR)  # Remove write permissions
-        with self.assertLogs(level='ERROR') as log_capture:
-            result = update_filemark(self.configuration, self.marks_base,
-                                     os.path.join('noaccess', 'test.mark'),
-                                     time.time())
+        with self.assertLogs(level="ERROR") as log_capture:
+            result = update_filemark(
+                self.configuration,
+                self.marks_base,
+                os.path.join("noaccess", "test.mark"),
+                time.time(),
+            )
         self.assertFalse(result)
-        self.assertTrue(any('Permission denied' in msg for msg in
-                            log_capture.output))
+        self.assertTrue(
+            any("Permission denied" in msg for msg in log_capture.output)
+        )

     @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user")
     def test_get_filemark_permission_denied(self):
@@ -188,8 +209,9 @@
         # Remove read permissions through parent dir
         os.chmod(self.marks_base, 0)

-        result = get_filemark(self.configuration, self.marks_base,
-                              TEST_MARKS_FILE)
+        result = get_filemark(
+            self.configuration, self.marks_base, TEST_MARKS_FILE
+        )
         self.assertIsNone(result)
         # Restore permissions so cleanup works
         os.chmod(self.marks_base, stat.S_IRWXU)
@@ -198,20 +220,23 @@ def test_reset_filemark_string_mark_list(self):
         """Test reset_filemark handles single string mark_list"""
         self._prepare_mark_for_test(TEST_MARKS_FILE)

-        reset_result = reset_filemark(self.configuration, self.marks_base,
-                                      TEST_MARKS_FILE)
+        reset_result = reset_filemark(
+            self.configuration, self.marks_base, TEST_MARKS_FILE
+        )
         self.assertTrue(reset_result)
         self._verify_mark_after_test(TEST_MARKS_FILE, 0)

     def test_reset_filemark_invalid_mark_list(self):
         """Test reset_filemark fails with invalid mark_list type"""
-        with self.assertLogs(level='ERROR') as log_capture:
-            reset_result = reset_filemark(self.configuration, self.marks_base,
-                                          {'invalid': 'type'})
+        with self.assertLogs(level="ERROR") as log_capture:
+            reset_result = reset_filemark(
+                self.configuration, self.marks_base, {"invalid": "type"}
+            )
         self.assertFalse(reset_result)
-        self.assertTrue(any('invalid mark list' in msg for msg in
-                            log_capture.output))
+        self.assertTrue(
+            any("invalid mark list" in msg for msg in log_capture.output)
+        )

     def test_reset_filemark_all_missing_dir(self):
         """Test reset_filemark handles missing directory when mark_list=None"""
@@ -222,41 +247,48 @@

     @unittest.skipIf(os.getuid() == 0, "access check is ignored as priv user")
     def test_reset_filemark_partial_perms_failure(self):
         """Test reset_filemark with partial failure due to permissions"""
-        valid_mark = 'valid.mark'
-        invalid_mark = 'invalid.mark'
+        valid_mark = "valid.mark"
+        invalid_mark = "invalid.mark"
         invalid_path = os.path.join(self.marks_base, invalid_mark)
         # Create both marks but remove access to the latter
         self._prepare_mark_for_test(valid_mark)
         self._prepare_mark_for_test(invalid_mark)
         os.chmod(invalid_path, stat.S_IRUSR)  # Remove write permissions
-        with self.assertLogs(level='ERROR') as log_capture:
-            reset_result = reset_filemark(self.configuration, self.marks_base,
-                                          [valid_mark, invalid_mark])
+        with self.assertLogs(level="ERROR") as log_capture:
+            reset_result = reset_filemark(
+                self.configuration, self.marks_base, [valid_mark, invalid_mark]
+            )
         self.assertFalse(reset_result)  # Should fail due to partial failure
-        self.assertTrue(any('Permission denied' in msg for msg in
-                            log_capture.output))
+        self.assertTrue(
+            any("Permission denied" in msg for msg in log_capture.output)
+        )
         self._verify_mark_after_test(valid_mark, 0)

     def test_reset_filemark_partial_file_prevents_directory_failure(self):
         """Test reset_filemark with partial failure due to a file in the way"""
-        valid_mark = 'valid.mark'
-        invalid_mark = os.path.join('obstruct', 'invalid.mark')
+        valid_mark = "valid.mark"
+        invalid_mark = os.path.join("obstruct", "invalid.mark")
         # Create valid mark and a file to prevent the invalid mark
         self._prepare_mark_for_test(valid_mark)
         # Create a file in the way to prevent subdir creation
-        self._prepare_mark_for_test('obstruct')
+        self._prepare_mark_for_test("obstruct")

-        with self.assertLogs(level='ERROR') as log_capture:
-            reset_result = reset_filemark(self.configuration, self.marks_base,
-                                          [valid_mark, invalid_mark])
+        with self.assertLogs(level="ERROR") as log_capture:
+            reset_result = reset_filemark(
+                self.configuration, self.marks_base, [valid_mark, invalid_mark]
+            )
         self.assertFalse(reset_result)  # Should fail due to partial failure
-        self.assertTrue(any('in the way' in msg or 'could not create' in msg
-                            for msg in log_capture.output))
+        self.assertTrue(
+            any(
+                "in the way" in msg or "could not create" in msg
+                for msg in log_capture.output
+            )
+        )
         self._verify_mark_after_test(valid_mark, 0)


-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_functionality_cat.py b/tests/test_mig_shared_functionality_cat.py
index 619c3dd4e..a8081c105 100644
--- a/tests/test_mig_shared_functionality_cat.py
+++ b/tests/test_mig_shared_functionality_cat.py
@@ -28,17 +28,25 @@
 """Unit tests of the MiG functionality file implementing the cat backend"""

 from __future__ import print_function
+
 import importlib
 import os
 import shutil
 import sys
 import unittest

-from tests.support import MIG_BASE, PY2, TEST_DATA_DIR, MigTestCase, testmain, \
-    temppath, ensure_dirs_exist
-
 from mig.shared.base import client_id_dir
-from mig.shared.functionality.cat import _main as submain, main as realmain
+from mig.shared.functionality.cat import _main as submain
+from mig.shared.functionality.cat import main as realmain
+from tests.support import (
+    MIG_BASE,
+    PY2,
+    TEST_DATA_DIR,
+    MigTestCase,
+    ensure_dirs_exist,
+    temppath,
+    testmain,
+)


 def create_http_environ(configuration):
@@ -47,139 +55,169 @@
     """
     environ = {}
-    environ['MIG_CONF'] = configuration.config_file
-    environ['HTTP_HOST'] = 'localhost'
-    environ['PATH_INFO'] = '/'
-    environ['REMOTE_ADDR'] = '127.0.0.1'
-    environ['SCRIPT_URI'] = ''.join(('https://', environ['HTTP_HOST'],
-                                     environ['PATH_INFO']))
+    environ["MIG_CONF"] = configuration.config_file
+    environ["HTTP_HOST"] = "localhost"
+    environ["PATH_INFO"] = "/"
+    environ["REMOTE_ADDR"] = "127.0.0.1"
+    environ["SCRIPT_URI"] = "".join(
+        ("https://", environ["HTTP_HOST"], environ["PATH_INFO"])
+    )
     return environ


 def _only_output_objects(output_objects, with_object_type=None):
-    return [o for o in output_objects if o['object_type'] == with_object_type]
+    return [o for o in output_objects if o["object_type"] == with_object_type]


 class MigSharedFunctionalityCat(MigTestCase):
     """Wrap unit tests for the corresponding module"""

-    TEST_CLIENT_ID = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com'
+    TEST_CLIENT_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com"

     def _provide_configuration(self):
-        return 'testconfig'
+        return "testconfig"

     def before_each(self):
-        self.test_user_dir = self._provision_test_user(self, self.TEST_CLIENT_ID)
+        self.test_user_dir = self._provision_test_user(
+            self, self.TEST_CLIENT_ID
+        )
         self.test_environ = create_http_environ(self.configuration)

     def assertSingleOutputObject(self, output_objects, with_object_type=None):
         assert with_object_type is not None
-        found_objects = _only_output_objects(output_objects,
-                                             with_object_type=with_object_type)
+        found_objects = _only_output_objects(
+            output_objects, with_object_type=with_object_type
+        )
         self.assertEqual(len(found_objects), 1)
         return found_objects[0]

     def test_file_serving_a_single_file_match(self):
-        with open(os.path.join(self.test_user_dir, 'foobar.txt'), 'w'):
+        with open(os.path.join(self.test_user_dir, "foobar.txt"), "w"):
             pass
         payload = {
-            'path': ['foobar.txt'],
+            "path": ["foobar.txt"],
         }
-        (output_objects, status) = submain(self.configuration, self.logger,
-                                           client_id=self.TEST_CLIENT_ID,
-                                           user_arguments_dict=payload,
-                                           environ=self.test_environ)
+        output_objects, status = submain(
+            self.configuration,
+            self.logger,
+            client_id=self.TEST_CLIENT_ID,
+            user_arguments_dict=payload,
+            environ=self.test_environ,
+        )
         # NOTE: start entry with headers and actual content
         self.assertEqual(len(output_objects), 2)
-        self.assertSingleOutputObject(output_objects,
-                                      with_object_type='file_output')
+        self.assertSingleOutputObject(
+            output_objects, with_object_type="file_output"
+        )

     def test_file_serving_at_limit(self):
         test_binary_file = os.path.realpath(
-            os.path.join(TEST_DATA_DIR, 'loading.gif'))
+            os.path.join(TEST_DATA_DIR, "loading.gif")
+        )
         test_binary_file_size = os.stat(test_binary_file).st_size
-        with open(test_binary_file, 'rb') as fh_test_file:
+        with open(test_binary_file, "rb") as fh_test_file:
             test_binary_file_data = fh_test_file.read()
-        shutil.copyfile(test_binary_file, os.path.join(
-            self.test_user_dir, 'loading.gif'))
+        shutil.copyfile(
+            test_binary_file, os.path.join(self.test_user_dir, "loading.gif")
+        )
         payload = {
-            'output_format': ['file'],
-            'path': ['loading.gif'],
+            "output_format": ["file"],
+            "path": ["loading.gif"],
         }
         self.configuration.wwwserve_max_bytes = test_binary_file_size
-        (output_objects, status) = submain(self.configuration, self.logger,
-                                           client_id=self.TEST_CLIENT_ID,
-                                           user_arguments_dict=payload,
-                                           environ=self.test_environ)
+        output_objects, status = submain(
+            self.configuration,
+            self.logger,
+            client_id=self.TEST_CLIENT_ID,
+            user_arguments_dict=payload,
+            environ=self.test_environ,
+        )
         self.assertEqual(len(output_objects), 2)
-        relevant_obj = self.assertSingleOutputObject(output_objects,
-                                                     with_object_type='file_output')
+        relevant_obj = self.assertSingleOutputObject(
+            output_objects, with_object_type="file_output"
+        )
-        self.assertEqual(len(relevant_obj['lines']), 1)
-        self.assertEqual(relevant_obj['lines'][0], test_binary_file_data)
+        self.assertEqual(len(relevant_obj["lines"]), 1)
+        self.assertEqual(relevant_obj["lines"][0], test_binary_file_data)

     def test_file_serving_over_limit_without_storage_protocols(self):
-        test_binary_file = os.path.realpath(os.path.join(TEST_DATA_DIR,
-                                                         'loading.gif'))
+        test_binary_file = os.path.realpath(
+            os.path.join(TEST_DATA_DIR, "loading.gif")
+        )
         test_binary_file_size = os.stat(test_binary_file).st_size
-        with open(test_binary_file, 'rb') as fh_test_file:
+        with open(test_binary_file, "rb") as fh_test_file:
             test_binary_file_data = fh_test_file.read()
-        shutil.copyfile(test_binary_file, os.path.join(self.test_user_dir,
-                                                       'loading.gif'))
+        shutil.copyfile(
+            test_binary_file, os.path.join(self.test_user_dir, "loading.gif")
+        )
         payload = {
-            'output_format': ['file'],
-            'path': ['loading.gif'],
+            "output_format": ["file"],
+            "path": ["loading.gif"],
         }
         # NOTE: override default storage_protocols to empty in this test
         self.configuration.storage_protocols = []
         self.configuration.wwwserve_max_bytes = test_binary_file_size - 1
-        (output_objects, status) = submain(self.configuration, self.logger,
-                                           client_id=self.TEST_CLIENT_ID,
-                                           user_arguments_dict=payload,
-                                           environ=self.test_environ)
+        output_objects, status = submain(
+            self.configuration,
+            self.logger,
+            client_id=self.TEST_CLIENT_ID,
+            user_arguments_dict=payload,
+            environ=self.test_environ,
+        )
         # NOTE: start entry with headers and actual error message
         self.assertEqual(len(output_objects), 2)
-        relevant_obj = self.assertSingleOutputObject(output_objects,
-                                                     with_object_type='error_text')
+        relevant_obj = self.assertSingleOutputObject(
+            output_objects, with_object_type="error_text"
+        )
-        self.assertEqual(relevant_obj['text'],
-                         "Site configuration prevents web serving contents "
-                         "bigger than 3896 bytes")
+        self.assertEqual(
+            relevant_obj["text"],
+            "Site configuration prevents web serving contents "
+            "bigger than 3896 bytes",
+        )

     def test_file_serving_over_limit_with_storage_protocols_sftp(self):
-        test_binary_file = os.path.realpath(os.path.join(TEST_DATA_DIR,
-                                                         'loading.gif'))
+        test_binary_file = os.path.realpath(
+            os.path.join(TEST_DATA_DIR, "loading.gif")
+        )
         test_binary_file_size = os.stat(test_binary_file).st_size
-        with open(test_binary_file, 'rb') as fh_test_file:
+        with open(test_binary_file, "rb") as fh_test_file:
             test_binary_file_data = fh_test_file.read()
-        shutil.copyfile(test_binary_file, os.path.join(self.test_user_dir,
-                                                       'loading.gif'))
+        shutil.copyfile(
+            test_binary_file, os.path.join(self.test_user_dir, "loading.gif")
+        )
         payload = {
-            'output_format': ['file'],
-            'path': ['loading.gif'],
+            "output_format": ["file"],
+            "path": ["loading.gif"],
         }
-        self.configuration.storage_protocols = ['sftp']
+        self.configuration.storage_protocols = ["sftp"]
         self.configuration.wwwserve_max_bytes = test_binary_file_size - 1
-        (output_objects, status) = submain(self.configuration, self.logger,
-                                           client_id=self.TEST_CLIENT_ID,
-                                           user_arguments_dict=payload,
-                                           environ=self.test_environ)
+        output_objects, status = submain(
+            self.configuration,
+            self.logger,
+            client_id=self.TEST_CLIENT_ID,
+            user_arguments_dict=payload,
+            environ=self.test_environ,
+        )
         # NOTE: start entry with headers and actual error message
-        relevant_obj = self.assertSingleOutputObject(output_objects,
-                                                     with_object_type='error_text')
+        relevant_obj = self.assertSingleOutputObject(
+            output_objects, with_object_type="error_text"
+        )
-        self.assertEqual(relevant_obj['text'],
-                         "Site configuration prevents web serving contents "
-                         "bigger than 3896 bytes - please use better "
-                         "alternatives (SFTP) to retrieve large data")
+        self.assertEqual(
+            relevant_obj["text"],
+            "Site configuration prevents web serving contents "
+            "bigger than 3896 bytes - please use better "
+            "alternatives (SFTP) to retrieve large data",
+        )

     @unittest.skipIf(PY2, "Python 3 only")
     def test_main_passes_environ(self):
@@ -187,17 +225,21 @@
             result = realmain(self.TEST_CLIENT_ID, {}, self.test_environ)
         except Exception as unexpectedexc:
             raise AssertionError(
-                "saw unexpected exception: %s" % (unexpectedexc,))
+                "saw unexpected exception: %s" % (unexpectedexc,)
+            )

-        (output_objects, status) = result
-        self.assertEqual(status[1], 'Client error')
+        output_objects, status = result
+        self.assertEqual(status[1], "Client error")

-        error_text_objects = _only_output_objects(output_objects,
-                                                  with_object_type='error_text')
+        error_text_objects = _only_output_objects(
+            output_objects, with_object_type="error_text"
+        )
         relevant_obj = error_text_objects[2]
         self.assertEqual(
-            relevant_obj['text'], 'Input arguments were rejected - not allowed for this script!')
+            relevant_obj["text"],
+            "Input arguments were rejected - not allowed for this script!",
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_functionality_datatransfer.py b/tests/test_mig_shared_functionality_datatransfer.py
index 3fe9e5fdc..6887fd5a1 100644
--- a/tests/test_mig_shared_functionality_datatransfer.py
+++ b/tests/test_mig_shared_functionality_datatransfer.py
@@ -28,18 +28,19 @@
 """Unit tests of the MiG functionality file implementing the datatransfer backend"""

 from __future__ import print_function
+
 import os

 import mig.shared.returnvalues as returnvalues
-from mig.shared.defaults import CSRF_MINIMAL
 from mig.shared.base import client_id_dir
-from mig.shared.functionality.datatransfer import _main as submain, main as realmain
-
+from mig.shared.defaults import CSRF_MINIMAL
+from mig.shared.functionality.datatransfer import _main as submain
+from mig.shared.functionality.datatransfer import main as realmain
 from tests.support import (
     MigTestCase,
-    testmain,
-    temppath,
     ensure_dirs_exist,
+    temppath,
+    testmain,
 )
@@ -66,25 +67,27 @@ def _only_output_objects(output_objects, with_object_type=None):

 class MigSharedFunctionalityDataTransfer(MigTestCase):
     """Wrap unit tests for the corresponding module"""

-    TEST_CLIENT_ID = (
-        "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com"
-    )
+    TEST_CLIENT_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com"

     def _provide_configuration(self):
         return "testconfig"

     def before_each(self):
-        self.test_user_dir = self._provision_test_user(self, self.TEST_CLIENT_ID)
+        self.test_user_dir = self._provision_test_user(
+            self, self.TEST_CLIENT_ID
+        )
         self.test_environ = create_http_environ(self.configuration)

     def test_default_disabled_site_transfer(self):
         self.assertFalse(self.configuration.site_enable_transfers)

         result = realmain(self.TEST_CLIENT_ID, {}, self.test_environ)
-        (output_objects, status) = result
+        output_objects, status = result
         self.assertEqual(status, returnvalues.OK)

-        text_objects = _only_output_objects(output_objects, with_object_type="text")
+        text_objects = _only_output_objects(
+            output_objects, with_object_type="text"
+        )
         self.assertEqual(len(text_objects), 1)
         self.assertIn("text", text_objects[0])
         text_object = text_objects[0]["text"]
@@ -95,7 +98,7 @@
         payload = {"action": ["show"]}

         self.configuration.site_enable_transfers = True
-        (output_objects, status) = submain(
+        output_objects, status = submain(
             self.configuration,
             self.logger,
             client_id=self.TEST_CLIENT_ID,
@@ -105,17 +108,22 @@
         self.assertEqual(status, returnvalues.OK)

         # We don't expect any text messages here
-        text_objects = _only_output_objects(output_objects, with_object_type="text")
+        text_objects = _only_output_objects(
+            output_objects, with_object_type="text"
+        )
         self.assertEqual(len(text_objects), 0)

     def test_deltransfer_without_transfer_id(self):
         non_existing_transfer_id = "non-existing-transfer-id"
-        payload = {"action": ["deltransfer"], "transfer_id": [non_existing_transfer_id]}
+        payload = {
+            "action": ["deltransfer"],
+            "transfer_id": [non_existing_transfer_id],
+        }

         self.configuration.site_enable_transfers = True
         self.configuration.site_csrf_protection = CSRF_MINIMAL
         self.test_environ["REQUEST_METHOD"] = "post"
-        (output_objects, status) = submain(
+        output_objects, status = submain(
             self.configuration,
             self.logger,
             client_id=self.TEST_CLIENT_ID,
@@ -129,7 +137,8 @@
         )
         self.assertEqual(len(error_text_objects), 1)
         self.assertEqual(
-            error_text_objects[0]["text"], "existing transfer_id is required for delete"
+            error_text_objects[0]["text"],
+            "existing transfer_id is required for delete",
         )

     def test_redotransfer_without_transfer_id(self):
@@ -142,7 +151,7 @@
         self.configuration.site_csrf_protection = CSRF_MINIMAL
         self.test_environ["REQUEST_METHOD"] = "post"
-        (output_objects, status) = submain(
+        output_objects, status = submain(
             self.configuration,
             self.logger,
             client_id=self.TEST_CLIENT_ID,
diff --git a/tests/test_mig_shared_install.py b/tests/test_mig_shared_install.py
index b06f3c928..e67793879 100644
--- a/tests/test_mig_shared_install.py
+++ b/tests/test_mig_shared_install.py
@@ -27,21 +27,28 @@

 """Unit tests for the migrid module pointed to in the filename"""

-from past.builtins import basestring
 import binascii
-from configparser import ConfigParser, NoSectionError, NoOptionError
 import difflib
 import io
 import os
 import pwd
 import sys
+from configparser import ConfigParser, NoOptionError, NoSectionError

-from tests.support import MIG_BASE, TEST_OUTPUT_DIR, MigTestCase, \
-    testmain, temppath, cleanpath, is_path_within
-from tests.support.fixturesupp import fixturepath
+from past.builtins import basestring

 from mig.shared.defaults import keyword_auto
 from mig.shared.install import determine_timezone, generate_confs
+from tests.support import (
+    MIG_BASE,
+    TEST_OUTPUT_DIR,
+    MigTestCase,
+    cleanpath,
+    is_path_within,
+    temppath,
+    testmain,
+)
+from tests.support.fixturesupp import fixturepath


 class DummyPwInfo:
@@ -70,40 +77,43 @@ class MigSharedInstall__determine_timezone(MigTestCase):

     def test_determines_tz_utc_fallback(self):
         timezone = determine_timezone(
-            _environ={}, _path_exists=lambda _: False, _print=noop)
+            _environ={}, _path_exists=lambda _: False, _print=noop
+        )

-        self.assertEqual(timezone, 'UTC')
+        self.assertEqual(timezone, "UTC")

     def test_determines_tz_via_environ(self):
-        example_environ = {
-            'TZ': 'Example/Enviromnent'
-        }
+        example_environ = {"TZ": "Example/Enviromnent"}
         timezone = determine_timezone(_environ=example_environ)

-        self.assertEqual(timezone, 'Example/Enviromnent')
+        self.assertEqual(timezone, "Example/Enviromnent")

     def test_determines_tz_via_localtime(self):
         def exists_localtime(value):
-            saw_call = value == '/etc/localtime'
+            saw_call = value == "/etc/localtime"
             exists_localtime.was_called = saw_call
             return saw_call
+
         exists_localtime.was_called = False

         timezone = determine_timezone(
-            _environ={}, _path_exists=exists_localtime)
+            _environ={}, _path_exists=exists_localtime
+        )

         self.assertTrue(exists_localtime.was_called)
         self.assertIsNotNone(timezone)

     def test_determines_tz_via_timedatectl(self):
         def exists_timedatectl(value):
-            saw_call = value == '/usr/bin/timedatectl'
+            saw_call = value == "/usr/bin/timedatectl"
             exists_timedatectl.was_called = saw_call
             return saw_call
+
         exists_timedatectl.was_called = False

         timezone = determine_timezone(
-            _environ={}, _path_exists=exists_timedatectl, _print=noop)
+            _environ={}, _path_exists=exists_timedatectl, _print=noop
+        )

         self.assertTrue(exists_timedatectl.was_called)
         self.assertIsNotNone(timezone)
@@ -131,46 +141,48 @@ def assertConfigKey(self, generated, section, key, expected):

         self.assertEqual(actual, expected)

     def test_creates_output_directory_and_adds_active_symlink(self):
-        symlink_path = temppath('confs', self)
-        cleanpath('confs-foobar', self)
+        symlink_path = temppath("confs", self)
+        cleanpath("confs-foobar", self)

-        generate_confs(self.output_path, destination_suffix='-foobar')
+        generate_confs(self.output_path, destination_suffix="-foobar")

-        path_kind = self.assertPathExists('confs-foobar')
+        path_kind = self.assertPathExists("confs-foobar")
         self.assertEqual(path_kind, "dir")
-        path_kind = self.assertPathExists('confs')
+        path_kind = self.assertPathExists("confs")
         self.assertEqual(path_kind, "symlink")

     def test_creates_output_directory_and_repairs_active_symlink(self):
-        expected_generated_dir = cleanpath('confs-foobar', self)
-        symlink_path = temppath('confs', self)
-        nowhere_path = temppath('confs-nowhere', self)
+        expected_generated_dir = cleanpath("confs-foobar", self)
+        symlink_path = temppath("confs", self)
+        nowhere_path = temppath("confs-nowhere", self)
         # arrange pre-existing symlink pointing nowhere
         os.symlink(nowhere_path, symlink_path)

-        generate_confs(self.output_path, destination_suffix='-foobar')
+        generate_confs(self.output_path, destination_suffix="-foobar")

         generated_dir = os.path.realpath(symlink_path)
         self.assertEqual(generated_dir, expected_generated_dir)

-    def test_creates_output_directory_containing_a_standard_local_configuration(self):
+    def test_creates_output_directory_containing_a_standard_local_configuration(
+        self,
+    ):
         fixture_dir = fixturepath("confs-stdlocal")
-        expected_generated_dir = cleanpath('confs-stdlocal', self)
-        symlink_path = temppath('confs', self)
+        expected_generated_dir = cleanpath("confs-stdlocal", self)
+        symlink_path = temppath("confs", self)

         generate_confs(
             self.output_path,
-            destination_suffix='-stdlocal',
-            user='testuser',
-            group='testgroup',
-            mig_code='/home/mig/mig',
-            mig_certs='/home/mig/certs',
-            mig_state='/home/mig/state',
-            timezone='Test/Place',
-            crypto_salt='_TEST_CRYPTO_SALT'.zfill(32),
-            digest_salt='_TEST_DIGEST_SALT'.zfill(32),
-            seafile_secret='_test-seafile-secret='.zfill(44),
-            seafile_ccnetid='_TEST_SEAFILE_CCNETID'.zfill(40),
+            destination_suffix="-stdlocal",
+            user="testuser",
+            group="testgroup",
+            mig_code="/home/mig/mig",
+            mig_certs="/home/mig/certs",
+            mig_state="/home/mig/state",
+            timezone="Test/Place",
+            crypto_salt="_TEST_CRYPTO_SALT".zfill(32),
+            digest_salt="_TEST_DIGEST_SALT".zfill(32),
+            seafile_secret="_test-seafile-secret=".zfill(44),
+            seafile_ccnetid="_TEST_SEAFILE_CCNETID".zfill(40),
             _getpwnam=create_dummy_gpwnam(4321, 1234),
         )
@@ -193,10 +205,11 @@ def test_kwargs_for_paths_auto(self):
         def capture_defaulted(*args, **kwargs):
             capture_defaulted.kwargs = kwargs
             return args[0]
+
         capture_defaulted.kwargs = None

-        (options, _) = generate_confs(
-            '/some/arbitrary/path',
+        options, _ = generate_confs(
+            "/some/arbitrary/path",
             _getpwnam=create_dummy_gpwnam(4321, 1234),
             _prepare=capture_defaulted,
             _writefiles=noop,
@@ -204,123 +217,139 @@
         )

         defaulted = capture_defaulted.kwargs
-        self.assertPathWithin(defaulted['mig_certs'], MIG_BASE)
-        self.assertPathWithin(defaulted['mig_state'], MIG_BASE)
+        self.assertPathWithin(defaulted["mig_certs"], MIG_BASE)
+        self.assertPathWithin(defaulted["mig_state"], MIG_BASE)

     def test_creates_output_files_with_datasafety(self):
         fixture_dir = fixturepath("confs-stdlocal")
-        expected_generated_dir = cleanpath('confs-stdlocal', self)
-        symlink_path = temppath('confs', self)
+        expected_generated_dir = cleanpath("confs-stdlocal", self)
+        symlink_path = temppath("confs", self)

         generate_confs(
             self.output_path,
             destination=symlink_path,
-            destination_suffix='-stdlocal',
-            datasafety_link='TEST_DATASAFETY_LINK',
-            datasafety_text='TEST_DATASAFETY_TEXT',
+            destination_suffix="-stdlocal",
+            datasafety_link="TEST_DATASAFETY_LINK",
+            datasafety_text="TEST_DATASAFETY_TEXT",
             _getpwnam=create_dummy_gpwnam(4321, 1234),
         )

-        actual_file = self.assertFileExists('confs-stdlocal/MiGserver.conf')
+        actual_file = self.assertFileExists("confs-stdlocal/MiGserver.conf")
         self.assertConfigKey(
-            actual_file, 'SITE', 'datasafety_link', expected='TEST_DATASAFETY_LINK')
+            actual_file,
+            "SITE",
+            "datasafety_link",
+            expected="TEST_DATASAFETY_LINK",
+        )
         self.assertConfigKey(
-            actual_file, 'SITE', 'datasafety_text', expected='TEST_DATASAFETY_TEXT')
+            actual_file,
+            "SITE",
+            "datasafety_text",
+            expected="TEST_DATASAFETY_TEXT",
+        )

     def test_creates_output_files_with_permanent_freeze(self):
         fixture_dir = fixturepath("confs-stdlocal")
-        expected_generated_dir = cleanpath('confs-stdlocal', self)
-        symlink_path = temppath('confs', self)
+        expected_generated_dir = cleanpath("confs-stdlocal", self)
+        symlink_path = temppath("confs", self)

-        for arg_val in ('yes', 'no', 'foo bar baz'):
+        for arg_val in ("yes", "no", "foo bar baz"):
             generate_confs(
                 self.output_path,
                 destination=symlink_path,
-                destination_suffix='-stdlocal',
+                destination_suffix="-stdlocal",
                 permanent_freeze=arg_val,
                 _getpwnam=create_dummy_gpwnam(4321, 1234),
             )

-            actual_file = self.assertFileExists('confs-stdlocal/MiGserver.conf')
+            actual_file = self.assertFileExists("confs-stdlocal/MiGserver.conf")
             self.assertConfigKey(
-                actual_file, 'SITE', 'permanent_freeze', expected=arg_val)
+                actual_file, "SITE", "permanent_freeze", expected=arg_val
+            )

     def test_options_for_source_auto(self):
-        (options, _) = generate_confs(
-            '/some/arbitrary/path',
+        options, _ = generate_confs(
+            "/some/arbitrary/path",
             source=keyword_auto,
             _getpwnam=create_dummy_gpwnam(4321, 1234),
             _prepare=noop,
             _writefiles=noop,
             _instructions=noop,
         )
-        expected_template_dir = os.path.join(MIG_BASE, 'mig/install')
+        expected_template_dir = os.path.join(MIG_BASE, "mig/install")

-        self.assertEqual(options['template_dir'], expected_template_dir)
+        self.assertEqual(options["template_dir"], expected_template_dir)

     def test_options_for_source_relative(self):
-        (options, _) = generate_confs(
-            '/current/working/directory/mig/install',
-            source='.',
+        options, _ = generate_confs(
+            "/current/working/directory/mig/install",
+            source=".",
             _getpwnam=create_dummy_gpwnam(4321, 1234),
             _prepare=noop,
             _writefiles=noop,
             _instructions=noop,
         )

-        self.assertEqual(options['template_dir'],
-                         '/current/working/directory/mig/install')
+        self.assertEqual(
+            options["template_dir"], "/current/working/directory/mig/install"
+        )

     def test_options_for_destination_auto(self):
-        (options, _) = generate_confs(
-            '/some/arbitrary/path',
+        options, _ = generate_confs(
+            "/some/arbitrary/path",
             destination=keyword_auto,
-            destination_suffix='_suffix',
+            destination_suffix="_suffix",
             _getpwnam=create_dummy_gpwnam(4321, 1234),
             _prepare=noop,
             _writefiles=noop,
             _instructions=noop,
         )

-        self.assertEqual(options['destination_link'],
-                         '/some/arbitrary/path/confs')
-        self.assertEqual(options['destination_dir'],
-                         '/some/arbitrary/path/confs_suffix')
+        self.assertEqual(
+            options["destination_link"], "/some/arbitrary/path/confs"
+        )
+        self.assertEqual(
+            options["destination_dir"], "/some/arbitrary/path/confs_suffix"
+        )

     def test_options_for_destination_relative(self):
-        (options, _) = generate_confs(
-            '/current/working/directory',
-            destination='generate-confs',
-            destination_suffix='_suffix',
+        options, _ = generate_confs(
+            "/current/working/directory",
+            destination="generate-confs",
+            destination_suffix="_suffix",
             _getpwnam=create_dummy_gpwnam(4321, 1234),
             _prepare=noop,
             _writefiles=noop,
             _instructions=noop,
         )

-        self.assertEqual(options['destination_link'],
-                         '/current/working/directory/generate-confs')
-        self.assertEqual(options['destination_dir'],
-                         '/current/working/directory/generate-confs_suffix')
+        self.assertEqual(
+            options["destination_link"],
+            "/current/working/directory/generate-confs",
+        )
+        self.assertEqual(
+            options["destination_dir"],
+            "/current/working/directory/generate-confs_suffix",
+        )

     def test_options_for_destination_absolute(self):
-        (options, _) = generate_confs(
-            '/current/working/directory',
-            destination='/some/other/place/confs',
-            destination_suffix='_suffix',
+        options, _ = generate_confs(
+            "/current/working/directory",
+            destination="/some/other/place/confs",
+            destination_suffix="_suffix",
             _getpwnam=create_dummy_gpwnam(4321, 1234),
             _prepare=noop,
             _writefiles=noop,
             _instructions=noop,
         )

-        self.assertEqual(options['destination_link'],
-                         '/some/other/place/confs')
-        self.assertEqual(options['destination_dir'],
-                         '/some/other/place/confs_suffix')
+        self.assertEqual(options["destination_link"], "/some/other/place/confs")
+        self.assertEqual(
+            options["destination_dir"], "/some/other/place/confs_suffix"
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_jupyter.py b/tests/test_mig_shared_jupyter.py
index 3eb53bf63..461282df8 100644
--- a/tests/test_mig_shared_jupyter.py
+++ b/tests/test_mig_shared_jupyter.py
@@ -34,9 +34,14 @@

 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__))))

-from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \
-    cleanpath, testmain
 from mig.shared.jupyter import gen_openid_template
+from tests.support import (
+    TEST_OUTPUT_DIR,
+    FakeConfiguration,
+    MigTestCase,
+    cleanpath,
+    testmain,
+)


 def noop(*args):
@@ -48,7 +53,8 @@ class MigSharedJupyter(MigTestCase):

     def test_jupyter_gen_openid_template_openid_auth(self):
         filled = gen_openid_template(
-            "/some-jupyter-url", "MyDefine", "OpenID", _print=noop)
+            "/some-jupyter-url", "MyDefine", "OpenID", _print=noop
+        )

         expected = """
@@ -63,7 +69,8 @@ def test_jupyter_gen_openid_template_openid_auth(self):

     def test_jupyter_gen_openid_template_oidc_auth(self):
         filled = gen_openid_template(
-            "/some-jupyter-url", "MyDefine", "openid-connect", _print=noop)
+            "/some-jupyter-url", "MyDefine", "openid-connect", _print=noop
+        )

         expected = """
@@ -79,26 +86,26 @@ def test_jupyter_gen_openid_template_oidc_auth(self):

     def test_jupyter_gen_openid_template_invalid_url_type(self):
         with self.assertRaises(AssertionError):
-            filled = gen_openid_template(None, "MyDefine",
-                                         "openid-connect")
+            filled = gen_openid_template(None, "MyDefine", "openid-connect")

     def test_jupyter_gen_openid_template_invalid_define_type(self):
         with self.assertRaises(AssertionError):
-            filled = gen_openid_template("/some-jupyter-url", None,
-                                         "no-such-auth-type")
+            filled = gen_openid_template(
+                "/some-jupyter-url", None, "no-such-auth-type"
+            )

     def test_jupyter_gen_openid_template_missing_auth_type(self):
         with self.assertRaises(AssertionError):
-            filled = gen_openid_template("/some-jupyter-url", "MyDefine",
-                                         None)
+            filled = gen_openid_template("/some-jupyter-url", "MyDefine", None)

     def test_jupyter_gen_openid_template_invalid_auth_type(self):
         with self.assertRaises(AssertionError):
-            filled = gen_openid_template("/some-jupyter-url", "MyDefine",
-                                         "no-such-auth-type")
+            filled = gen_openid_template(
+                "/some-jupyter-url", "MyDefine", "no-such-auth-type"
+            )

 # TODO: add more coverage of module


-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_localfile.py b/tests/test_mig_shared_localfile.py
index 62a5112aa..be9718bd7 100644
--- a/tests/test_mig_shared_localfile.py
+++ b/tests/test_mig_shared_localfile.py
@@ -27,21 +27,19 @@

 """Unit tests for the migrid module pointed to in the filename"""

-from contextlib import contextmanager
 import errno
 import fcntl
 import os
 import sys
+from contextlib import contextmanager

-sys.path.append(os.path.realpath(
-    os.path.join(os.path.dirname(__file__), "..")))
-
-from tests.support import MigTestCase, temppath, testmain
+sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "..")))

-from mig.shared.serverfile import LOCK_EX
 from mig.shared.localfile import LocalFile
+from mig.shared.serverfile import LOCK_EX
+from tests.support import MigTestCase, temppath, testmain

-DUMMY_FILE = 'some_file'
+DUMMY_FILE = "some_file"


 @contextmanager
@@ -72,7 +70,7 @@ def assertPathLockedExclusive(self, file_path):
                 # we were errantly able to acquire a lock, mark errored
                 reraise = AssertionError("RERAISE_MUST_UNLOCK")
             except Exception as maybe_err:
-                if getattr(maybe_err, 'errno', None) == errno.EAGAIN:
+                if getattr(maybe_err, "errno", None) == errno.EAGAIN:
                     # this is the expected exception - the logic tried to lock
                     # a file that was (as we intended) already locked, meaning
                     # this assertion has succeeded so we do not need to raise
@@ -83,17 +81,19 @@ def assertPathLockedExclusive(self, file_path):

         if reraise is not None:
             # if marked errored and locked, cleanup the lock we acquired but shouldn't
-            if str(reraise) == 'RERAISE_MUST_UNLOCK':
+            if str(reraise) == "RERAISE_MUST_UNLOCK":
                 fcntl.flock(conflicting_f, fcntl.LOCK_NB | fcntl.LOCK_UN)
             # raise a user-friendly error to avoid nested raise
             raise AssertionError(
-                "expected locked file: %s" % self.pretty_display_path(file_path))
+                "expected locked file: %s"
+                % self.pretty_display_path(file_path)
+            )

     def test_localfile_locking(self):
         some_file = temppath(DUMMY_FILE, self)

-        with managed_localfile(LocalFile(some_file, 'w')) as lfd:
+        with managed_localfile(LocalFile(some_file, "w")) as lfd:
             lfd.lock(LOCK_EX)
             self.assertEqual(lfd.get_lock_mode(), LOCK_EX)
@@ -101,5 +101,5 @@ def test_localfile_locking(self):
             self.assertPathLockedExclusive(some_file)


-if __name__ == '__main__':
+if __name__ == "__main__":
     testmain()
diff --git a/tests/test_mig_shared_pwcrypto.py b/tests/test_mig_shared_pwcrypto.py
index 4be502d84..626bb94f1 100644
--- a/tests/test_mig_shared_pwcrypto.py
+++ b/tests/test_mig_shared_pwcrypto.py
@@ -32,64 +32,86 @@

 import sys
 import unittest

-from tests.support import MigTestCase, FakeConfiguration, \
-    cleanpath, temppath, testmain
-
-from mig.shared.defaults import POLICY_NONE, POLICY_WEAK, POLICY_MEDIUM, \
-    POLICY_HIGH, POLICY_MODERN, POLICY_CUSTOM, PASSWORD_POLICIES
+from mig.shared.defaults import (
+    PASSWORD_POLICIES,
+    POLICY_CUSTOM,
+    POLICY_HIGH,
+    POLICY_MEDIUM,
+    POLICY_MODERN,
+    POLICY_NONE,
+    POLICY_WEAK,
+)
 from mig.shared.pwcrypto import *
+from tests.support import (
+    FakeConfiguration,
+    MigTestCase,
+    cleanpath,
+    temppath,
+    testmain,
+)

 DUMMY_USER = "dummy-user"
 DUMMY_ID = "dummy-id"
 # NOTE: these passwords are not and should not ever be used outside unit tests
-DUMMY_WEAK_PW = 'foobar'
-DUMMY_MEDIUM_PW = 'QZFnCp7h'
-DUMMY_HIGH_PW = 'QZFnp7I-GZ'
-DUMMY_MODERN_PW = 'QZFnCp7hmI1G'
-DUMMY_GENERATED_PW = '7hmI1GnCpQZF'
+DUMMY_WEAK_PW = "foobar"
+DUMMY_MEDIUM_PW = "QZFnCp7h"
+DUMMY_HIGH_PW = "QZFnp7I-GZ"
+DUMMY_MODERN_PW = "QZFnCp7hmI1G"
+DUMMY_GENERATED_PW = "7hmI1GnCpQZF"
 DUMMY_WEAK_PW_MD5 = "3858f62230ac3c915f300c664312c63f"
-DUMMY_WEAK_PW_SHA256 = \
+DUMMY_WEAK_PW_SHA256 = (
     "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2"
-DUMMY_WEAK_PW_PBKDF2 = \
+)
+DUMMY_WEAK_PW_PBKDF2 = (
     "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$epib2rEg/HYTQZFnCp7hmIGZ6rzHnViy"
-DUMMY_MEDIUM_PW_PBKDF2 = \
+)
+DUMMY_MEDIUM_PW_PBKDF2 = (
    "PBKDF2$sha256$10000$ebQHnDX1rzY9Rizb$0vUJ9/4ThhsN4cRaKYmOj4N0YKEsozTr"
-DUMMY_HIGH_PW_PBKDF2 = \
+)
+DUMMY_HIGH_PW_PBKDF2 = (
"PBKDF2$sha256$10000$HR+KcqLyQe3v0WSk$CtxMAomi8JHiI7gWc/PH5Ey00zW1Now3" +) DUMMY_MODERN_PW_MD5 = "a06d169a171ef7d4383b212457162d93" -DUMMY_MODERN_PW_SHA256 = \ +DUMMY_MODERN_PW_SHA256 = ( "d293dcb9762c87641ea1decbfe76d84ec51b13d6a1e688cdf1a838eebc5bb1a9" -DUMMY_MODERN_PW_PBKDF2 = \ +) +DUMMY_MODERN_PW_PBKDF2 = ( "PBKDF2$sha256$10000$MDAwMDAwMDAwMDAw$B22uw6C7C4VFiYAe4Vf10n58FHrn1pjX" -DUMMY_MODERN_PW_DIGEST = \ - "DIGEST$custom$CONFSALT$64756D6D792D7265616C6D3A64756D6D792D7520DE71261F96A2FE48A67DD0877F2A2C" +) +DUMMY_MODERN_PW_DIGEST = "DIGEST$custom$CONFSALT$64756D6D792D7265616C6D3A64756D6D792D7520DE71261F96A2FE48A67DD0877F2A2C" DUMMY_MODERN_DIGEST_SCRAMBLE = "53BB031C1F96A2FE48A67DD0877F2A2C" DUMMY_MODERN_PW_SCRAMBLE = "53BB031C1F96A2FE48A67DD0877F2A2C" -DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED = b'xRsT1qHmiM3xqDjuvFuxqQ==.g4-Gt83uRrdvVWwX0SF1iMza3NyKJbp2sEYVkw==.ICAgIG1pZ3JpZCBhdXRoZW50aWNhdGVkMjA1MDAxMDE=' -DUMMY_MODERN_PW_RESET_TOKEN = b'gAAAAABo63hYqeHE7Db93pMxWn1sWzj2Z-6td2UhA5gKYa4KV096ndV-AO0pp6hrR9jXKcwWAouLCMiNC0BRudeCAYHoBii15lTRbP9b7JzvJjeusbidjRxqcJg0om6bbtSK1Rz_RBTq_jhdAk4v-7PccWlZ15dVJ4j-X3X4zSsBWIOR5y6Y-bA=' -DUMMY_METHOD = 'dummy-method' -DUMMY_OPERATION = 'dummy-operation' -DUMMY_ARGS = {'dummy-key': 'dummy-val'} -DUMMY_CSRF_TOKEN = '351cc47e0cd5c155fa4c4d3d0a6f1ee8f20eeb293ba13d59ede9d2a789687d3d' -DUMMY_CSRF_TRUST_TOKEN = '466c0bacd045a060a201c4e08c749c2e19743613422e0381ab0a57706c9fa2b8' -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED = b"xRsT1qHmiM3xqDjuvFuxqQ==.g4-Gt83uRrdvVWwX0SF1iMza3NyKJbp2sEYVkw==.ICAgIG1pZ3JpZCBhdXRoZW50aWNhdGVkMjA1MDAxMDE=" +DUMMY_MODERN_PW_RESET_TOKEN = b"gAAAAABo63hYqeHE7Db93pMxWn1sWzj2Z-6td2UhA5gKYa4KV096ndV-AO0pp6hrR9jXKcwWAouLCMiNC0BRudeCAYHoBii15lTRbP9b7JzvJjeusbidjRxqcJg0om6bbtSK1Rz_RBTq_jhdAk4v-7PccWlZ15dVJ4j-X3X4zSsBWIOR5y6Y-bA=" +DUMMY_METHOD = "dummy-method" +DUMMY_OPERATION = "dummy-operation" +DUMMY_ARGS = {"dummy-key": "dummy-val"} +DUMMY_CSRF_TOKEN = ( + "351cc47e0cd5c155fa4c4d3d0a6f1ee8f20eeb293ba13d59ede9d2a789687d3d" +) +DUMMY_CSRF_TRUST_TOKEN = ( + "466c0bacd045a060a201c4e08c749c2e19743613422e0381ab0a57706c9fa2b8" +) +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" # TODO: adjust password reset token helpers to handle configured services # it currently silently fails if not in migoid(c) or migcert # DUMMY_SERVICE = 'dummy-svc' -DUMMY_SERVICE = 'migoid' -DUMMY_REALM = 'dummy-realm' -DUMMY_PATH = 'dummy-path' -DUMMY_PATH_MD5 = 'd19033877452e8c217d3cddebbc37419' -DUMMY_SALT = b'53BB031C4ECCE4900BD64AB8EA361B6B' -DUMMY_ENTROPY = b'\xd2\x93\xdc\xb9v,\x87d\x1e\xa1\xde\xcb\xfev\xd8N\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9' -DUMMY_FERNET_KEY = 'NDg3OTcyNzE1NTQ2Nzc3ODYxNjc0NjRFRDZGMjNFQzY=' -DUMMY_AESGCM_KEY = b'48797271554677786167464ED6F23EC6' -DUMMY_AESGCM_STATIC_IV = b'\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9' -DUMMY_AESGCM_AAD_PREFIX = b'\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa9' -DUMMY_AESGCM_AAD = b' \xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa920500101' +DUMMY_SERVICE = "migoid" +DUMMY_REALM = "dummy-realm" +DUMMY_PATH = "dummy-path" +DUMMY_PATH_MD5 = "d19033877452e8c217d3cddebbc37419" +DUMMY_SALT = b"53BB031C4ECCE4900BD64AB8EA361B6B" +DUMMY_ENTROPY = b"\xd2\x93\xdc\xb9v,\x87d\x1e\xa1\xde\xcb\xfev\xd8N\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9" +DUMMY_FERNET_KEY = 
"NDg3OTcyNzE1NTQ2Nzc3ODYxNjc0NjRFRDZGMjNFQzY=" +DUMMY_AESGCM_KEY = b"48797271554677786167464ED6F23EC6" +DUMMY_AESGCM_STATIC_IV = ( + b"\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9" +) +DUMMY_AESGCM_AAD_PREFIX = b"\xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa9" +DUMMY_AESGCM_AAD = b" \xc5\x1b\x13\xd6\xa1\xe6\x88\xcd\xf1\xa88\xee\xbc[\xb1\xa9\xa88\xee\xbc[\xb1\xa920500101" # NOTE: we avoid any percent expansion values of actual date here to freeze AAD -DUMMY_FIXED_TIMESTAMP = '20500101' +DUMMY_FIXED_TIMESTAMP = "20500101" class MigSharedPwCrypto(MigTestCase): @@ -98,13 +120,15 @@ class MigSharedPwCrypto(MigTestCase): def before_each(self): test_user_home = temppath(DUMMY_HOME_DIR, self, ensure_dir=True) test_user_settings = cleanpath( - DUMMY_SETTINGS_DIR, self, ensure_dir=True) + DUMMY_SETTINGS_DIR, self, ensure_dir=True + ) # make two requisite root folders for the dummy user os.mkdir(os.path.join(test_user_home, DUMMY_USER)) os.mkdir(os.path.join(test_user_settings, DUMMY_USER)) # now create a configuration self.dummy_conf = FakeConfiguration( - user_home=test_user_home, user_settings=test_user_settings, + user_home=test_user_home, + user_settings=test_user_settings, site_password_policy="%s:12" % POLICY_MODERN, site_password_legacy_policy=POLICY_MEDIUM, site_password_cracklib=False, @@ -119,9 +143,11 @@ def before_each(self): # 'FakeConfiguration' has no 'site_password_legacy_policy' member # (no-member) unless we explicitly (re-)init it here self.dummy_conf.site_password_legacy_policy = getattr( - self.dummy_conf, 'site_password_legacy_policy', POLICY_NONE) - self.assertEqual(self.dummy_conf.site_password_legacy_policy, - POLICY_MEDIUM) + self.dummy_conf, "site_password_legacy_policy", POLICY_NONE + ) + self.assertEqual( + self.dummy_conf.site_password_legacy_policy, POLICY_MEDIUM + ) def test_best_crypt_salt(self): """Test selection of best salt based on salt availability in @@ -130,13 +156,13 @@ def test_best_crypt_salt(self): expected = DUMMY_SALT actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "best crypt salt not found") - self.dummy_conf.site_crypto_salt = '' + self.dummy_conf.site_crypto_salt = "" actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "2nd best crypt salt not found") - self.dummy_conf.site_password_salt = '' + self.dummy_conf.site_password_salt = "" actual = best_crypt_salt(self.dummy_conf) self.assertEqual(actual, expected, "3rd best crypt salt not found") - self.dummy_conf.site_digest_salt = '' + self.dummy_conf.site_digest_salt = "" actual = None try: actual = best_crypt_salt(self.dummy_conf) @@ -153,10 +179,10 @@ def test_password_requirements(self): self.assertEqual(expected[2], result[2], "failed pw req errors") expected = (8, 3, []) result = password_requirements( - self.dummy_conf.site_password_legacy_policy) + self.dummy_conf.site_password_legacy_policy + ) self.assertEqual(expected[0], result[0], "failed legacy pw req chars") - self.assertEqual(expected[1], result[1], - "failed legacy pw req classes") + self.assertEqual(expected[1], result[1], "failed legacy pw req classes") self.assertEqual(expected[2], result[2], "failed legacy pw req errors") def test_parse_password_policy(self): @@ -173,56 +199,60 @@ def test_parse_password_policy(self): def test_assure_password_strength(self): """Test assure password strength""" try: - allow_weak = assure_password_strength(self.dummy_conf, - DUMMY_WEAK_PW) + allow_weak = assure_password_strength( + self.dummy_conf, 
DUMMY_WEAK_PW + ) except ValueError as vae: allow_weak = False self.assertFalse(allow_weak, "allowed weak pw") try: - allow_weak = assure_password_strength(self.dummy_conf, - DUMMY_WEAK_PW, - allow_legacy=True) + allow_weak = assure_password_strength( + self.dummy_conf, DUMMY_WEAK_PW, allow_legacy=True + ) except ValueError as vae: allow_weak = False self.assertFalse(allow_weak, "allowed weak pw with legacy") # NOTE: only allow medium with legacy try: - allow_medium = assure_password_strength(self.dummy_conf, - DUMMY_MEDIUM_PW) + allow_medium = assure_password_strength( + self.dummy_conf, DUMMY_MEDIUM_PW + ) except ValueError as vae: allow_medium = False self.assertFalse(allow_medium, "allowed medium pw without legacy") try: - allow_medium = assure_password_strength(self.dummy_conf, - DUMMY_MEDIUM_PW, - allow_legacy=True) + allow_medium = assure_password_strength( + self.dummy_conf, DUMMY_MEDIUM_PW, allow_legacy=True + ) except ValueError as vae: allow_medium = False self.assertTrue(allow_medium, "refused medium pw with legacy") # NOTE: only allow high with legacy - not long enough for modern try: - allow_high = assure_password_strength(self.dummy_conf, - DUMMY_HIGH_PW) + allow_high = assure_password_strength( + self.dummy_conf, DUMMY_HIGH_PW + ) except ValueError as vae: allow_high = False self.assertFalse(allow_high, "allowed high pw without legacy") try: - allow_high = assure_password_strength(self.dummy_conf, - DUMMY_HIGH_PW, - allow_legacy=True) + allow_high = assure_password_strength( + self.dummy_conf, DUMMY_HIGH_PW, allow_legacy=True + ) except ValueError as vae: allow_high = False self.assertTrue(allow_high, "refused high pw with legacy") try: - allow_modern = assure_password_strength(self.dummy_conf, - DUMMY_MODERN_PW) + allow_modern = assure_password_strength( + self.dummy_conf, DUMMY_MODERN_PW + ) except ValueError as vae: allow_modern = False self.assertTrue(allow_modern, "refused modern pw") try: - allow_modern = assure_password_strength(self.dummy_conf, - DUMMY_MODERN_PW, - allow_legacy=True) + allow_modern = assure_password_strength( + self.dummy_conf, DUMMY_MODERN_PW, allow_legacy=True + ) except ValueError as vae: allow_modern = False self.assertTrue(allow_modern, "refused modern pw with legacy") @@ -277,7 +307,7 @@ def test_make_hash_fixed_seed(self): random seed. """ expected = DUMMY_MODERN_PW_PBKDF2 - actual = make_hash(DUMMY_MODERN_PW, _urandom=lambda vlen: b'0' * vlen) + actual = make_hash(DUMMY_MODERN_PW, _urandom=lambda vlen: b"0" * vlen) self.assertEqual(actual, expected, "mismatch hashing string") def test_make_hash_constant_string(self): @@ -285,17 +315,25 @@ def test_make_hash_constant_string(self): random seed. I.e. the value may differ across interpreter invocations but remains constant in same interpreter. 
""" - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) self.assertEqual(first, second, "same seed hashing is not constant") def test_check_hash_reject_weak(self): """Test basic hash checking of a constant weak complexity password""" expected = DUMMY_WEAK_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_WEAK_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_WEAK_PW, + expected, + strict_policy=True, + ) self.assertFalse(result, "check hash should fail on weak pw") def test_check_hash_reject_medium_without_legacy(self): @@ -303,9 +341,15 @@ def test_check_hash_reject_medium_without_legacy(self): without legacy password support. """ expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertFalse(result, "check hash strict should fail on medium pw") def test_check_hash_accept_medium_with_legacy(self): @@ -313,9 +357,15 @@ def test_check_hash_accept_medium_with_legacy(self): with legacy password support. """ expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True, - allow_legacy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + allow_legacy=True, + ) self.assertTrue(result, "check hash with legacy must accept medium pw") def test_check_hash_accept_high(self): @@ -324,9 +374,15 @@ def test_check_hash_accept_high(self): """ expected = DUMMY_HIGH_PW_PBKDF2 self.dummy_conf.site_password_policy = POLICY_HIGH - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_HIGH_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_HIGH_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertTrue(result, "check hash must accept high complexity pw") def test_check_hash_accept_modern(self): @@ -334,32 +390,57 @@ def test_check_hash_accept_modern(self): without legacy password support. 
""" expected = DUMMY_MODERN_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, strict_policy=True, - allow_legacy=False) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + strict_policy=True, + allow_legacy=False, + ) self.assertTrue(result, "check hash must accept modern complexity pw") def test_check_hash_fixed(self): """Test basic hash checking of a fixed string""" expected = DUMMY_MEDIUM_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=True, + ) self.assertFalse(result, "check hash should reject medium pw") - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MEDIUM_PW, expected, strict_policy=False, - allow_legacy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MEDIUM_PW, + expected, + strict_policy=False, + allow_legacy=True, + ) self.assertTrue(result, "check hash failed medium pw when not strict") expected = DUMMY_MODERN_PW_PBKDF2 - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, strict_policy=True) + result = check_hash( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + strict_policy=True, + ) self.assertTrue(result, "check hash failed modern pw") def test_check_hash_random(self): """Test basic hashing and hash checking of a random string""" random_pw = generate_random_password(self.dummy_conf) expected = make_hash(random_pw) - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, expected) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, random_pw, expected + ) self.assertTrue(result, "mismatch in random hash and check") def test_make_hash_variation(self): @@ -367,10 +448,12 @@ def test_make_hash_variation(self): I.e. the value likely remains constant in same interpreter but differs across interpreter invocations. """ - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen]) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen] + ) self.assertNotEqual(first, second, "varying seed hashing is constant") def test_check_hash_despite_variation(self): @@ -378,15 +461,19 @@ def test_check_hash_despite_variation(self): I.e. the hash value differs across interpreter invocations but testing the same password against each succeeds. 
""" - first = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[:vlen]) - second = make_hash(DUMMY_MODERN_PW, - _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen]) - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, first) + first = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[:vlen] + ) + second = make_hash( + DUMMY_MODERN_PW, _urandom=lambda vlen: DUMMY_SALT[::-1][:vlen] + ) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, DUMMY_MODERN_PW, first + ) self.assertTrue(result, "mismatch in 1st random password hash check") - result = check_hash(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, second) + result = check_hash( + self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, DUMMY_MODERN_PW, second + ) self.assertTrue(result, "mismatch in 2nd random password hash check") def test_scramble_digest_fixed(self): @@ -404,16 +491,23 @@ def test_unscramble_digest_fixed(self): def test_make_digest_fixed(self): """Test basic digest of a fixed string""" expected = DUMMY_MODERN_PW_DIGEST - result = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) + result = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) self.assertEqual(expected, result, "mismatch in fixed digest") def test_check_digest_fixed(self): """Test basic digest checking of a fixed string""" expected = DUMMY_MODERN_PW_DIGEST - result = check_digest(self.dummy_conf, DUMMY_SERVICE, DUMMY_REALM, - DUMMY_USER, DUMMY_MODERN_PW, expected, - DUMMY_SALT) + result = check_digest( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_REALM, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + DUMMY_SALT, + ) self.assertTrue(result, "mismatch in fixed digest check") def test_check_digest_random(self): @@ -421,8 +515,15 @@ def test_check_digest_random(self): random_pw = generate_random_password(self.dummy_conf) random_salt = base64.b16encode(os.urandom(16)) expected = make_digest(DUMMY_REALM, DUMMY_USER, random_pw, random_salt) - result = check_digest(self.dummy_conf, DUMMY_SERVICE, DUMMY_REALM, - DUMMY_USER, random_pw, expected, random_salt) + result = check_digest( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_REALM, + DUMMY_USER, + random_pw, + expected, + random_salt, + ) self.assertTrue(result, "mismatch in random digest check") def test_digest_constant_string(self): @@ -430,10 +531,12 @@ def test_digest_constant_string(self): random seed. I.e. the value may differ across interpreter invocations but remains constant in same interpreter. 
""" - first = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) - second = make_digest(DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, - DUMMY_SALT) + first = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) + second = make_digest( + DUMMY_REALM, DUMMY_USER, DUMMY_MODERN_PW, DUMMY_SALT + ) self.assertEqual(first, second, "basic digest is not constant") def test_scramble_password_fixed(self): @@ -457,8 +560,14 @@ def test_make_scramble_fixed(self): def test_check_scramble_fixed(self): """Test basic scramble checking of a fixed string""" expected = DUMMY_MODERN_PW_SCRAMBLE - result = check_scramble(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, DUMMY_SALT) + result = check_scramble( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + DUMMY_SALT, + ) self.assertTrue(result, "mismatch in fixed scramble check") def test_check_scramble_random(self): @@ -466,8 +575,14 @@ def test_check_scramble_random(self): random_pw = generate_random_password(self.dummy_conf) random_salt = base64.b16encode(os.urandom(16)) expected = make_scramble(DUMMY_MODERN_PW, random_salt) - result = check_scramble(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - DUMMY_MODERN_PW, expected, random_salt) + result = check_scramble( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + DUMMY_MODERN_PW, + expected, + random_salt, + ) self.assertTrue(result, "mismatch in random scramble check") def test_scramble_constant_string(self): @@ -479,8 +594,7 @@ def test_scramble_constant_string(self): self.assertEqual(first, second, "basic scramble is not constant") def test_prepare_fernet_key(self): - """Test basic fernet secret key preparation on a fixed string. - """ + """Test basic fernet secret key preparation on a fixed string.""" expected = DUMMY_FERNET_KEY result = prepare_fernet_key(self.dummy_conf) self.assertEqual(expected, result, "failed prepare fernet key") @@ -493,8 +607,7 @@ def test_fernet_encrypt_decrypt(self): self.assertEqual(random_pw, result, "failed fernet enc+dec") def test_prepare_aesgcm_key(self): - """Test basic aesgcm secret key preparation on a fixed string. - """ + """Test basic aesgcm secret key preparation on a fixed string.""" expected = DUMMY_AESGCM_KEY result = prepare_aesgcm_key(self.dummy_conf) self.assertEqual(expected, result, "failed prepare aesgcm key") @@ -519,8 +632,11 @@ def test_prepare_aesgcm_aad_fixed(self): entropy and date value. """ expected = DUMMY_AESGCM_AAD - result = prepare_aesgcm_aad(self.dummy_conf, DUMMY_AESGCM_AAD_PREFIX, - aad_stamp=DUMMY_FIXED_TIMESTAMP) + result = prepare_aesgcm_aad( + self.dummy_conf, + DUMMY_AESGCM_AAD_PREFIX, + aad_stamp=DUMMY_FIXED_TIMESTAMP, + ) self.assertEqual(expected, result, "failed prepare aesgcm aad") def test_aesgcm_encrypt_static_iv_fixed(self): @@ -528,9 +644,12 @@ def test_aesgcm_encrypt_static_iv_fixed(self): initialization vector and date helper. """ expected = DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED - result = aesgcm_encrypt_password(self.dummy_conf, DUMMY_MODERN_PW, - init_vector=DUMMY_AESGCM_STATIC_IV, - aad_stamp=DUMMY_FIXED_TIMESTAMP) + result = aesgcm_encrypt_password( + self.dummy_conf, + DUMMY_MODERN_PW, + init_vector=DUMMY_AESGCM_STATIC_IV, + aad_stamp=DUMMY_FIXED_TIMESTAMP, + ) self.assertEqual(expected, result, "failed fixed aesgcm static iv enc") def test_aesgcm_decrypt_static_iv_fixed(self): @@ -538,9 +657,11 @@ def test_aesgcm_decrypt_static_iv_fixed(self): initialization vector. 
""" expected = DUMMY_MODERN_PW - result = aesgcm_decrypt_password(self.dummy_conf, - DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED, - init_vector=DUMMY_AESGCM_STATIC_IV) + result = aesgcm_decrypt_password( + self.dummy_conf, + DUMMY_MODERN_PW_AESGCM_SIV_ENCRYPTED, + init_vector=DUMMY_AESGCM_STATIC_IV, + ) self.assertEqual(expected, result, "failed fixed aesgcm static iv den") def test_aesgcm_encrypt_decrypt_static_iv(self): @@ -550,10 +671,12 @@ def test_aesgcm_encrypt_decrypt_static_iv(self): random_pw = generate_random_password(self.dummy_conf) entropy = make_safe_hash(random_pw, False) static_iv = prepare_aesgcm_iv(self.dummy_conf, iv_entropy=entropy) - expected = aesgcm_encrypt_password(self.dummy_conf, random_pw, - init_vector=static_iv) - result = aesgcm_decrypt_password(self.dummy_conf, expected, - init_vector=static_iv) + expected = aesgcm_encrypt_password( + self.dummy_conf, random_pw, init_vector=static_iv + ) + result = aesgcm_decrypt_password( + self.dummy_conf, expected, init_vector=static_iv + ) self.assertEqual(random_pw, result, "failed aesgcm static iv enc+dec") def test_make_encrypt_decrypt(self): @@ -581,97 +704,127 @@ def test_check_encrypt(self): """Test basic password simple encrypt and decrypt on a random string""" random_pw = generate_random_password(self.dummy_conf) # IMPORTANT: only aesgcm_static generates constant enc value! - encrypted = make_encrypt(self.dummy_conf, random_pw, - algo="fernet") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='fernet') + encrypted = make_encrypt(self.dummy_conf, random_pw, algo="fernet") + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="fernet", + ) self.assertFalse(result, "invalid match in fernet encrypt check") encrypted = make_encrypt(self.dummy_conf, random_pw, algo="aesgcm") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='aesgcm') + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="aesgcm", + ) self.assertFalse(result, "invalid match in aesgcm encrypt check") - encrypted = make_encrypt(self.dummy_conf, random_pw, - algo="aesgcm_static") - result = check_encrypt(self.dummy_conf, DUMMY_SERVICE, DUMMY_USER, - random_pw, encrypted, algo='aesgcm_static') + encrypted = make_encrypt( + self.dummy_conf, random_pw, algo="aesgcm_static" + ) + result = check_encrypt( + self.dummy_conf, + DUMMY_SERVICE, + DUMMY_USER, + random_pw, + encrypted, + algo="aesgcm_static", + ) self.assertTrue(result, "mismatch in aesgcm_static encrypt check") def test_assure_reset_supported(self): """Test basic password reset token check for a fixed user and auth""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 - result = assure_reset_supported(self.dummy_conf, dummy_user, - DUMMY_SERVICE) + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 + result = assure_reset_supported( + self.dummy_conf, dummy_user, DUMMY_SERVICE + ) self.assertTrue(result, "failed assure reset supported") # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_generate_reset_token_fixed(self): """Test basic password reset token generate for a fixed string""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 + dummy_user = {"distinguished_name": DUMMY_USER} 
+ dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 timestamp = 42 expected = DUMMY_MODERN_PW_RESET_TOKEN - result = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) - self.assertEqual(expected, result, - "failed generate password reset token") + result = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) + self.assertEqual( + expected, result, "failed generate password reset token" + ) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_parse_reset_token_fixed(self): """Test basic password reset token parse for a fixed string""" timestamp = 42 - result = parse_reset_token(self.dummy_conf, - DUMMY_MODERN_PW_RESET_TOKEN, - DUMMY_SERVICE) + result = parse_reset_token( + self.dummy_conf, DUMMY_MODERN_PW_RESET_TOKEN, DUMMY_SERVICE + ) self.assertEqual(result[0], timestamp, "failed parse token time") - self.assertEqual(result[1], DUMMY_MODERN_PW_PBKDF2, - "failed parse token hash") + self.assertEqual( + result[1], DUMMY_MODERN_PW_PBKDF2, "failed parse token hash" + ) # TODO: adjust API to allow enabling the next test @unittest.skipIf(True, "requires constant random seed") def test_verify_reset_token_fixed(self): """Test basic password reset token verify for a fixed string""" - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = DUMMY_MODERN_PW_PBKDF2 + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = DUMMY_MODERN_PW_PBKDF2 timestamp = 42 - result = verify_reset_token(self.dummy_conf, dummy_user, - DUMMY_MODERN_PW_RESET_TOKEN, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, + dummy_user, + DUMMY_MODERN_PW_RESET_TOKEN, + DUMMY_SERVICE, + timestamp, + ) self.assertTrue(result, "failed password reset token handling") def test_password_reset_token_generate_and_verify(self): """Test basic password reset token generate and verify helper""" random_pw = generate_random_password(self.dummy_conf) hashed_pw = make_hash(random_pw) - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = hashed_pw + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = hashed_pw timestamp = 42 - expected = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) + expected = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) parsed = parse_reset_token(self.dummy_conf, expected, DUMMY_SERVICE) self.assertEqual(parsed[0], timestamp, "failed parse token time") self.assertEqual(parsed[1], hashed_pw, "failed parse token hash") - result = verify_reset_token(self.dummy_conf, dummy_user, expected, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, dummy_user, expected, DUMMY_SERVICE, timestamp + ) self.assertTrue(result, "failed password reset token handling") def test_password_reset_token_verify_expired(self): """Test basic password reset token verify failure after it expired""" random_pw = generate_random_password(self.dummy_conf) hashed_pw = make_hash(random_pw) - dummy_user = {'distinguished_name': DUMMY_USER} - dummy_user['password_hash'] = hashed_pw + dummy_user = {"distinguished_name": DUMMY_USER} + dummy_user["password_hash"] = hashed_pw timestamp = 42 - expected = generate_reset_token(self.dummy_conf, dummy_user, - DUMMY_SERVICE, timestamp) + expected = generate_reset_token( + self.dummy_conf, dummy_user, DUMMY_SERVICE, timestamp + ) parsed = parse_reset_token(self.dummy_conf, 
expected, DUMMY_SERVICE) self.assertEqual(parsed[0], timestamp, "failed parse token time") self.assertEqual(parsed[1], hashed_pw, "failed parse token hash") timestamp = 4242 - result = verify_reset_token(self.dummy_conf, dummy_user, expected, - DUMMY_SERVICE, timestamp) + result = verify_reset_token( + self.dummy_conf, dummy_user, expected, DUMMY_SERVICE, timestamp + ) self.assertFalse(result, "failed password reset token expiry check") def test_make_csrf_token_fixed(self): @@ -679,8 +832,9 @@ def test_make_csrf_token_fixed(self): client id. """ expected = DUMMY_CSRF_TOKEN - result = make_csrf_token(self.dummy_conf, DUMMY_METHOD, - DUMMY_OPERATION, DUMMY_ID) + result = make_csrf_token( + self.dummy_conf, DUMMY_METHOD, DUMMY_OPERATION, DUMMY_ID + ) self.assertEqual(expected, result, "failed make csrf token") def test_make_csrf_trust_token_fixed(self): @@ -688,8 +842,9 @@ def test_make_csrf_trust_token_fixed(self): client id and args. """ expected = DUMMY_CSRF_TRUST_TOKEN - result = make_csrf_trust_token(self.dummy_conf, DUMMY_METHOD, - DUMMY_OPERATION, DUMMY_ARGS, DUMMY_ID) + result = make_csrf_trust_token( + self.dummy_conf, DUMMY_METHOD, DUMMY_OPERATION, DUMMY_ARGS, DUMMY_ID + ) self.assertEqual(expected, result, "failed make csrf trust token") def test_generate_random_password(self): @@ -704,15 +859,16 @@ def test_generate_random_password_fixed_seed(self): """Test basic generate password is constant with fixed random seed""" expected = DUMMY_GENERATED_PW result = generate_random_password(self.dummy_conf) - self.assertEqual(expected, result, - "failed generate password with fixed seed") + self.assertEqual( + expected, result, "failed generate password with fixed seed" + ) class MigSharedPwCrypto__legacy_main(MigTestCase): """Legacy tests for corresponding module self-checks""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" # TODO: migrate remaining inline checks from module here instead def test_existing_main(self): @@ -721,16 +877,22 @@ def raise_on_error_exit(exit_code): if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'legacy test failure: %s' % (identifying_message,)) + "legacy test failure: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): raise_on_error_exit.last_print = value - legacy_main(self.configuration, print=record_last_print, _exit=raise_on_error_exit) + legacy_main( + self.configuration, + print=record_last_print, + _exit=raise_on_error_exit, + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_safeeval.py b/tests/test_mig_shared_safeeval.py index 4bec4aa90..c015875be 100644 --- a/tests/test_mig_shared_safeeval.py +++ b/tests/test_mig_shared_safeeval.py @@ -30,13 +30,11 @@ import os import sys -from tests.support import MigTestCase, testmain - from mig.shared.safeeval import * - +from tests.support import MigTestCase, testmain PWD_STR = os.getcwd() -PWD_BYTES = PWD_STR.encode('utf8') +PWD_BYTES = PWD_STR.encode("utf8") class MigSharedSafeeval(MigTestCase): @@ -44,49 +42,57 @@ class MigSharedSafeeval(MigTestCase): def test_subprocess_call(self): """Check that pwd call without args succeeds""" - retval = subprocess_call(['pwd'], stdout=subprocess_pipe) + retval = subprocess_call(["pwd"], stdout=subprocess_pipe) self.assertEqual(retval, 0, "unexpected subprocess call pwd retval") def 
test_subprocess_call_invalid(self): """Check that pwd call with invalid arg fails""" - retval = subprocess_call(['pwd', '-h'], stderr=subprocess_pipe) - self.assertNotEqual(retval, 0, - "unexpected subprocess call nosuchcommand retval") + retval = subprocess_call(["pwd", "-h"], stderr=subprocess_pipe) + self.assertNotEqual( + retval, 0, "unexpected subprocess call pwd -h retval" + ) def test_subprocess_check_output(self): """Check that pwd command output matches getcwd as bytes""" - data = subprocess_check_output(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_pipe).strip() - self.assertEqual(data, PWD_BYTES, - "mismatch in subprocess check pwd output") + data = subprocess_check_output( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_pipe + ).strip() + self.assertEqual( + data, PWD_BYTES, "mismatch in subprocess check pwd output" + ) def test_subprocess_check_output_text(self): """Check that pwd command output matches getcwd as string""" - data = subprocess_check_output(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_pipe, - text=True).strip() - self.assertEqual(data, PWD_STR, - "mismatch in subprocess check pwd output") + data = subprocess_check_output( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_pipe, text=True + ).strip() + self.assertEqual( + data, PWD_STR, "mismatch in subprocess check pwd output" + ) def test_subprocess_popen(self): """Check that pwd popen output matches getcwd as bytes""" - proc = subprocess_popen(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_stdout) + proc = subprocess_popen( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_stdout + ) retval = proc.wait() data = proc.stdout.read().strip() - self.assertEqual(data, PWD_BYTES, - "mismatch in subprocess popen pwd output") + self.assertEqual( + data, PWD_BYTES, "mismatch in subprocess popen pwd output" + ) def test_subprocess_popen_text(self): """Check that pwd popen output matches getcwd as string""" orig = os.getcwd() - proc = subprocess_popen(['pwd'], stdout=subprocess_pipe, - stderr=subprocess_stdout, text=True) + proc = subprocess_popen( + ["pwd"], stdout=subprocess_pipe, stderr=subprocess_stdout, text=True + ) retval = proc.wait() data = proc.stdout.read().strip() - self.assertEqual(data, PWD_STR, - "mismatch in subprocess popen pwd output") + self.assertEqual( + data, PWD_STR, "mismatch in subprocess popen pwd output" + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_safeinput.py b/tests/test_mig_shared_safeinput.py index 4a01dddd8..15e8b91d6 100644 --- a/tests/test_mig_shared_safeinput.py +++ b/tests/test_mig_shared_safeinput.py @@ -30,15 +30,26 @@ import base64 import codecs import sys + from past.builtins import basestring, unicode +from mig.shared.safeinput import ( + VALID_NAME_CHARACTERS, + InputException, + filter_commonname, +) +from mig.shared.safeinput import main as safeinput_main +from mig.shared.safeinput import ( + valid_alphanumeric, + valid_base_url, + valid_commonname, + valid_complex_url, + valid_path, + valid_printable, + valid_url, +) from tests.support import MigTestCase, testmain -from mig.shared.safeinput import main as safeinput_main, InputException, \ - filter_commonname, valid_alphanumeric, valid_commonname, valid_path, \ - valid_printable, valid_base_url, valid_url, valid_complex_url, \ - VALID_NAME_CHARACTERS - PY2 = sys.version_info[0] == 2 @@ -46,18 +57,18 @@ def as_string_of_unicode(value): assert isinstance(value, basestring) if not is_string_of_unicode(value): assert PY2, "unreachable
unless Python 2" - return unicode(codecs.decode(value, 'utf8')) + return unicode(codecs.decode(value, "utf8")) return value def is_string_of_unicode(value): - return type(value) == type(u'') + return type(value) == type("") def _hex_wrap(val): """Insert a clearly marked hex representation of val""" # Please keep aligned with helper in mig/shared/functionality/autocreate.py - return ".X%s" % base64.b16encode(val.encode('utf8')).decode('utf8') + return ".X%s" % base64.b16encode(val.encode("utf8")).decode("utf8") class TestMigSharedSafeInput(MigTestCase): @@ -69,7 +80,7 @@ class TestMigSharedSafeInput(MigTestCase): PRINTABLE_CHARS = "abc123!@#" ACCENTED_VALID = "Renée Müller" ACCENTED_INVALID_EXOTIC = "Źaćâř" - DECOMPOSED_UNICODE = u"å" # a + combining ring above + DECOMPOSED_UNICODE = "å" # a + combining ring above # Commonname specific test constants APOSTROPHE_FULL_NAME = "John O'Connor" @@ -77,26 +88,28 @@ class TestMigSharedSafeInput(MigTestCase): APOSTROPHE_FULL_NAME_HEX = "John O.X27Connor" COMMONNAME_PERMITTED = ( - 'Firstname Lastname', - 'Test Æøå', - 'Test Überh4x0r', - 'Harry S. Truman', - u'Unicode æøå') + "Firstname Lastname", + "Test Æøå", + "Test Überh4x0r", + "Harry S. Truman", + "Unicode æøå", + ) COMMONNAME_PROHIBITED = ( "Invalid D'Angelo", - 'Test Maybe Invalid Źacãŕ', - 'Test Invalid ?', - 'Test HTML Invalid ') + "Test Maybe Invalid Źacãŕ", + "Test Invalid ?", + "Test HTML Invalid ", + ) - BASE_URL = 'https://www.migrid.org' - REGULAR_URL = 'https://www.migrid.org/wsgi-bin/ls.py?path=README&flags=v' - COMPLEX_URL = 'https://www.migrid.org/abc123@some.org/ls.py?path=R+D#HERE' - INVALID_URL = 'https://www.migrid.org/¾½§' + BASE_URL = "https://www.migrid.org" + REGULAR_URL = "https://www.migrid.org/wsgi-bin/ls.py?path=README&flags=v" + COMPLEX_URL = "https://www.migrid.org/abc123@some.org/ls.py?path=R+D#HERE" + INVALID_URL = "https://www.migrid.org/¾½§" def _provide_configuration(self): """Provide test configuration""" - return 'testconfig' + return "testconfig" def test_commonname_valid(self): """Test valid_commonname with acceptable and prohibited names""" @@ -130,7 +143,7 @@ def test_commonname_filter(self): self.assertTrue(len(filtered_cn) < len(test_cn_unicode)) # With default skip all chars in filtered_cn must be in original overlap = [i for i in filtered_cn if i in test_cn_unicode] - self.assertEqual(''.join(overlap), filtered_cn) + self.assertEqual("".join(overlap), filtered_cn) def test_commonname_filter_hexlify_illegal(self): """Test filter_commonname with hex encoding of illegal chars""" @@ -145,21 +158,23 @@ def test_commonname_filter_hexlify_illegal(self): filtered_cn = filter_commonname(test_cn, illegal_handler=_hex_wrap) # Invalid should be replaced with hexlify illegal_handler self.assertNotEqual(filtered_cn, test_cn_unicode) - self.assertIn('.X', filtered_cn) + self.assertIn(".X", filtered_cn) self.assertTrue(len(filtered_cn) > len(test_cn_unicode)) def test_filter_commonname_apostrophe_name_skip_illegal(self): """Test apostrophe handling with skip illegal_handler""" - result = filter_commonname(self.APOSTROPHE_FULL_NAME, - illegal_handler=None) + result = filter_commonname( + self.APOSTROPHE_FULL_NAME, illegal_handler=None + ) self.assertNotEqual(result, self.APOSTROPHE_FULL_NAME) self.assertNotIn("'", result) self.assertEqual(result, self.APOSTROPHE_FULL_NAME_SKIP) def test_filter_commonname_apostrophe_name_hexlify_illegal(self): """Test apostrophe handling with hex encode illegal_handler""" - result = filter_commonname(self.APOSTROPHE_FULL_NAME, - 
illegal_handler=_hex_wrap) + result = filter_commonname( + self.APOSTROPHE_FULL_NAME, illegal_handler=_hex_wrap + ) self.assertNotEqual(result, self.APOSTROPHE_FULL_NAME) self.assertNotIn("'", result) self.assertEqual(result, self.APOSTROPHE_FULL_NAME_HEX) @@ -283,14 +298,17 @@ class TestMigSharedSafeInput__legacy(MigTestCase): # TODO: migrate all legacy self-check functionality into the above? def test_existing_main(self): """Run built-in self-tests and check output""" + def raise_on_error_exit(exit_code): if exit_code != 0: if raise_on_error_exit.last_print is not None: identifying_message = raise_on_error_exit.last_print else: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'failure in unittest/testcore: %s' % (identifying_message,)) + "failure in unittest/testcore: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): @@ -300,5 +318,5 @@ def record_last_print(value): safeinput_main(_exit=raise_on_error_exit, _print=record_last_print) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_serial.py b/tests/test_mig_shared_serial.py index c577e0b4e..c42dc0d8b 100644 --- a/tests/test_mig_shared_serial.py +++ b/tests/test_mig_shared_serial.py @@ -30,12 +30,17 @@ import os import sys +from mig.shared.serial import * from tests.support import MigTestCase, temppath, testmain -from mig.shared.serial import * class BasicSerial(MigTestCase): - BASIC_OBJECT = {'abc': 123, 'def': 'def', 'ghi': 42.0, 'accented': 'TéstÆøå'} + BASIC_OBJECT = { + "abc": 123, + "def": "def", + "ghi": 42.0, + "accented": "TéstÆøå", + } def test_pickle_string(self): orig = BasicSerial.BASIC_OBJECT @@ -49,5 +54,6 @@ def test_pickle_file(self): data = load(tmp_path) self.assertEqual(data, orig, "mismatch pickling string") -if __name__ == '__main__': + +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_settings.py b/tests/test_mig_shared_settings.py index 87b9e02e1..dabcbbc9e 100644 --- a/tests/test_mig_shared_settings.py +++ b/tests/test_mig_shared_settings.py @@ -32,23 +32,31 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.settings import load_settings, update_settings, \ - parse_and_save_settings +from mig.shared.settings import ( + load_settings, + parse_and_save_settings, + update_settings, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) DUMMY_USER = "dummy-user" -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_SETTINGS_DIR = "dummy_user_settings" DUMMY_SETTINGS_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SETTINGS_DIR) -DUMMY_SYSTEM_FILES_DIR = 'dummy_system_files' +DUMMY_SYSTEM_FILES_DIR = "dummy_system_files" DUMMY_SYSTEM_FILES_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_SYSTEM_FILES_DIR) -DUMMY_TMP_DIR = 'dummy_tmp' -DUMMY_TMP_FILE = 'settings.mRSL' +DUMMY_TMP_DIR = "dummy_tmp" +DUMMY_TMP_FILE = "settings.mRSL" DUMMY_TMP_PATH = os.path.join(TEST_OUTPUT_DIR, DUMMY_TMP_DIR) DUMMY_MRSL_PATH = os.path.join(DUMMY_TMP_PATH, DUMMY_TMP_FILE) -DUMMY_USER_INTERFACE = ['V3', 'V42'] -DUMMY_DEFAULT_UI = 'V42' +DUMMY_USER_INTERFACE = ["V3", "V42"] +DUMMY_DEFAULT_UI = "V42" DUMMY_INIT_MRSL = """ ::EMAIL:: john@doe.org @@ -65,10 +73,12 @@ ::SITE_USER_MENU:: people """ -DUMMY_CONF = FakeConfiguration(user_settings=DUMMY_SETTINGS_PATH, - 
mig_system_files=DUMMY_SYSTEM_FILES_PATH, - user_interface=DUMMY_USER_INTERFACE, - new_user_default_ui=DUMMY_DEFAULT_UI) +DUMMY_CONF = FakeConfiguration( + user_settings=DUMMY_SETTINGS_PATH, + mig_system_files=DUMMY_SYSTEM_FILES_PATH, + user_interface=DUMMY_USER_INTERFACE, + new_user_default_ui=DUMMY_DEFAULT_UI, +) class MigSharedSettings(MigTestCase): @@ -82,28 +92,31 @@ def test_settings_save_load(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, 'settings') + saved_path = os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER, "settings") self.assertTrue(os.path.exists(saved_path)) settings = load_settings(DUMMY_USER, DUMMY_CONF) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(settings, dict)) - self.assertEqual(settings['EMAIL'], ['john@doe.org']) - self.assertEqual(settings['SITE_USER_MENU'], - ['sharelinks', 'people', 'peers']) + self.assertEqual(settings["EMAIL"], ["john@doe.org"]) + self.assertEqual( + settings["SITE_USER_MENU"], ["sharelinks", "people", "peers"] + ) # NOTE: we no longer auto save default values for optional vars for key in settings.keys(): - self.assertTrue(key in ['EMAIL', 'SITE_USER_MENU']) + self.assertTrue(key in ["EMAIL", "SITE_USER_MENU"]) # Any saved USER_INTERFACE value must match configured default if set - self.assertEqual(settings.get('USER_INTERFACE', DUMMY_DEFAULT_UI), - DUMMY_DEFAULT_UI) + self.assertEqual( + settings.get("USER_INTERFACE", DUMMY_DEFAULT_UI), DUMMY_DEFAULT_UI + ) def test_settings_replace(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) @@ -113,25 +126,27 @@ def test_settings_replace(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_UPDATE_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) updated = load_settings(DUMMY_USER, DUMMY_CONF) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(updated, dict)) - self.assertEqual(updated['EMAIL'], ['jane@doe.org']) - self.assertEqual(updated['SITE_USER_MENU'], ['people']) + self.assertEqual(updated["EMAIL"], ["jane@doe.org"]) + self.assertEqual(updated["SITE_USER_MENU"], ["people"]) def test_update_settings(self): os.makedirs(os.path.join(DUMMY_SETTINGS_PATH, DUMMY_USER)) @@ -141,20 +156,21 @@ def test_update_settings(self): os.makedirs(os.path.join(DUMMY_TMP_PATH)) cleanpath(DUMMY_TMP_DIR, self) - with open(DUMMY_MRSL_PATH, 'w') as mrsl_fd: + with open(DUMMY_MRSL_PATH, "w") as mrsl_fd: mrsl_fd.write(DUMMY_INIT_MRSL) save_status, save_msg = parse_and_save_settings( - DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF) + 
DUMMY_MRSL_PATH, DUMMY_USER, DUMMY_CONF + ) self.assertTrue(save_status) self.assertFalse(save_msg) - changes = {'EMAIL': ['john@doe.org', 'jane@doe.org']} + changes = {"EMAIL": ["john@doe.org", "jane@doe.org"]} defaults = {} updated = update_settings(DUMMY_USER, DUMMY_CONF, changes, defaults) # NOTE: updated should be a non-empty dict at this point self.assertTrue(isinstance(updated, dict)) - self.assertEqual(updated['EMAIL'], ['john@doe.org', 'jane@doe.org']) + self.assertEqual(updated["EMAIL"], ["john@doe.org", "jane@doe.org"]) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_ssh.py b/tests/test_mig_shared_ssh.py index 7aa4ba3ec..954bd8c1a 100644 --- a/tests/test_mig_shared_ssh.py +++ b/tests/test_mig_shared_ssh.py @@ -32,10 +32,18 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, testmain -from mig.shared.ssh import supported_pub_key_parsers, parse_pub_key, \ - generate_ssh_rsa_key_pair +from mig.shared.ssh import ( + generate_ssh_rsa_key_pair, + parse_pub_key, + supported_pub_key_parsers, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + testmain, +) class MigSharedSsh(MigTestCase): @@ -45,11 +53,11 @@ def test_ssh_key_generate_and_parse(self): parsers = supported_pub_key_parsers() # NOTE: should return a non-empty dict of algos and parsers self.assertTrue(parsers) - self.assertTrue('ssh-rsa' in parsers) + self.assertTrue("ssh-rsa" in parsers) # Generate common sized keys and parse the result for keysize in (2048, 3072, 4096): - (priv_key, pub_key) = generate_ssh_rsa_key_pair(size=keysize) + priv_key, pub_key = generate_ssh_rsa_key_pair(size=keysize) self.assertTrue(priv_key) self.assertTrue(pub_key) @@ -57,15 +65,16 @@ def test_ssh_key_generate_and_parse(self): try: parsed = parse_pub_key(pub_key) except ValueError as vae: - #print("Error in parsing pub key: %r" % vae) + # print("Error in parsing pub key: %r" % vae) parsed = None self.assertIsNotNone(parsed) - (priv_key, pub_key) = generate_ssh_rsa_key_pair(size=keysize, - encode_utf8=True) + priv_key, pub_key = generate_ssh_rsa_key_pair( + size=keysize, encode_utf8=True + ) self.assertTrue(priv_key) self.assertTrue(pub_key) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_tlsserver.py b/tests/test_mig_shared_tlsserver.py index b63881c61..3c5f1d062 100644 --- a/tests/test_mig_shared_tlsserver.py +++ b/tests/test_mig_shared_tlsserver.py @@ -30,11 +30,16 @@ import os import unittest +# Imports required for the unit test wrapping +from mig.shared.defaults import ( + STRONG_TLS_CIPHERS, + STRONG_TLS_CURVES, + STRONG_TLS_LEGACY_CIPHERS, +) + # Imports of the code under test from mig.shared.tlsserver import hardened_ssl_context, ssl -# Imports required for the unit test wrapping -from mig.shared.defaults import STRONG_TLS_CIPHERS, STRONG_TLS_LEGACY_CIPHERS, \ - STRONG_TLS_CURVES + # Imports required for the unit tests themselves from tests.support import MigTestCase @@ -136,14 +141,14 @@ class MigSharedTlsServer(MigTestCase): def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' + return "testconfig" # TODO: move the key+cert e.g. to data/X with helper function to load them? 
def _prepare_key_cert(self, key_path, cert_path): """Save key and cert file for use in real ssl tests""" - with open(key_path, 'w') as key_fd: + with open(key_path, "w") as key_fd: key_fd.write(TEST_HOST_KEY) - with open(cert_path, 'w') as cert_fd: + with open(cert_path, "w") as cert_fd: cert_fd.write(TEST_HOST_CERT) def before_each(self): @@ -164,21 +169,21 @@ def test_hardened_ssl_context_options_default(self): STRONG_TLS_CURVES, False, True, - False + False, ) # Verify options are set expected_options = ( - getattr(ssl, 'OP_NO_SSLv2', 0x1000000) | - getattr(ssl, 'OP_NO_SSLv3', 0x2000000) | - getattr(ssl, 'OP_NO_TLSv1', 0x4000000) | - getattr(ssl, 'OP_NO_TLSv1_1', 0x10000000) | - getattr(ssl, 'OP_NO_COMPRESSION', 0x20000) | - getattr(ssl, 'OP_CIPHER_SERVER_PREFERENCE', 0x400000) | - getattr(ssl, 'OP_SINGLE_ECDH_USE', 0x80000) | - getattr(ssl, 'OP_SINGLE_DH_USE', 0x100000) | - getattr(ssl, 'OP_NO_RENEGOTIATION', 0x40000000) | - getattr(ssl, 'OP_RENEGOTIATION', 0x40000000) + getattr(ssl, "OP_NO_SSLv2", 0x1000000) + | getattr(ssl, "OP_NO_SSLv3", 0x2000000) + | getattr(ssl, "OP_NO_TLSv1", 0x4000000) + | getattr(ssl, "OP_NO_TLSv1_1", 0x10000000) + | getattr(ssl, "OP_NO_COMPRESSION", 0x20000) + | getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0x400000) + | getattr(ssl, "OP_SINGLE_ECDH_USE", 0x80000) + | getattr(ssl, "OP_SINGLE_DH_USE", 0x100000) + | getattr(ssl, "OP_NO_RENEGOTIATION", 0x40000000) + | getattr(ssl, "OP_RENEGOTIATION", 0x40000000) ) # Verify the options were OR'd into the context @@ -198,20 +203,20 @@ def test_hardened_ssl_context_options_tls1_1_only(self): STRONG_TLS_CURVES, True, False, - False + False, ) # Verify options are set expected_options = ( - getattr(ssl, 'OP_NO_SSLv2', 0x1000000) | - getattr(ssl, 'OP_NO_SSLv3', 0x2000000) | - getattr(ssl, 'OP_NO_TLSv1_2', 0x8000000) | - getattr(ssl, 'OP_NO_COMPRESSION', 0x20000) | - getattr(ssl, 'OP_CIPHER_SERVER_PREFERENCE', 0x400000) | - getattr(ssl, 'OP_SINGLE_ECDH_USE', 0x80000) | - getattr(ssl, 'OP_SINGLE_DH_USE', 0x100000) | - getattr(ssl, 'OP_NO_RENEGOTIATION', 0x40000000) | - getattr(ssl, 'OP_RENEGOTIATION', 0x40000000) + getattr(ssl, "OP_NO_SSLv2", 0x1000000) + | getattr(ssl, "OP_NO_SSLv3", 0x2000000) + | getattr(ssl, "OP_NO_TLSv1_2", 0x8000000) + | getattr(ssl, "OP_NO_COMPRESSION", 0x20000) + | getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0x400000) + | getattr(ssl, "OP_SINGLE_ECDH_USE", 0x80000) + | getattr(ssl, "OP_SINGLE_DH_USE", 0x100000) + | getattr(ssl, "OP_NO_RENEGOTIATION", 0x40000000) + | getattr(ssl, "OP_RENEGOTIATION", 0x40000000) ) # Verify the options were OR'd into the context @@ -231,22 +236,22 @@ def test_hardened_ssl_context_options_tls1_3_only(self): STRONG_TLS_CURVES, False, False, - False + False, ) # Verify options are set expected_options = ( - getattr(ssl, 'OP_NO_SSLv2', 0x1000000) | - getattr(ssl, 'OP_NO_SSLv3', 0x2000000) | - getattr(ssl, 'OP_NO_TLSv1', 0x4000000) | - getattr(ssl, 'OP_NO_TLSv1_1', 0x10000000) | - getattr(ssl, 'OP_NO_TLSv1_2', 0x8000000) | - getattr(ssl, 'OP_NO_COMPRESSION', 0x20000) | - getattr(ssl, 'OP_CIPHER_SERVER_PREFERENCE', 0x400000) | - getattr(ssl, 'OP_SINGLE_ECDH_USE', 0x80000) | - getattr(ssl, 'OP_SINGLE_DH_USE', 0x100000) | - getattr(ssl, 'OP_NO_RENEGOTIATION', 0x40000000) | - getattr(ssl, 'OP_RENEGOTIATION', 0x40000000) + getattr(ssl, "OP_NO_SSLv2", 0x1000000) + | getattr(ssl, "OP_NO_SSLv3", 0x2000000) + | getattr(ssl, "OP_NO_TLSv1", 0x4000000) + | getattr(ssl, "OP_NO_TLSv1_1", 0x10000000) + | getattr(ssl, "OP_NO_TLSv1_2", 0x8000000) + | getattr(ssl, "OP_NO_COMPRESSION", 
0x20000) + | getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0x400000) + | getattr(ssl, "OP_SINGLE_ECDH_USE", 0x80000) + | getattr(ssl, "OP_SINGLE_DH_USE", 0x100000) + | getattr(ssl, "OP_NO_RENEGOTIATION", 0x40000000) + | getattr(ssl, "OP_RENEGOTIATION", 0x40000000) ) # Verify the options were OR'd into the context @@ -266,26 +271,27 @@ def test_hardened_ssl_context_options_fail_reneg(self): STRONG_TLS_CURVES, False, True, - True + True, ) # Verify options are set expected_options = ( - getattr(ssl, 'OP_NO_SSLv2', 0x1000000) | - getattr(ssl, 'OP_NO_SSLv3', 0x2000000) | - getattr(ssl, 'OP_NO_TLSv1', 0x4000000) | - getattr(ssl, 'OP_NO_TLSv1_1', 0x10000000) | - getattr(ssl, 'OP_NO_COMPRESSION', 0x20000) | - getattr(ssl, 'OP_CIPHER_SERVER_PREFERENCE', 0x400000) | - getattr(ssl, 'OP_SINGLE_ECDH_USE', 0x80000) | - getattr(ssl, 'OP_SINGLE_DH_USE', 0x100000) | - getattr(ssl, 'OP_NO_RENEGOTIATION', 0x40000000) | - getattr(ssl, 'OP_RENEGOTIATION', 0x40000000) + getattr(ssl, "OP_NO_SSLv2", 0x1000000) + | getattr(ssl, "OP_NO_SSLv3", 0x2000000) + | getattr(ssl, "OP_NO_TLSv1", 0x4000000) + | getattr(ssl, "OP_NO_TLSv1_1", 0x10000000) + | getattr(ssl, "OP_NO_COMPRESSION", 0x20000) + | getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0x400000) + | getattr(ssl, "OP_SINGLE_ECDH_USE", 0x80000) + | getattr(ssl, "OP_SINGLE_DH_USE", 0x100000) + | getattr(ssl, "OP_NO_RENEGOTIATION", 0x40000000) + | getattr(ssl, "OP_RENEGOTIATION", 0x40000000) ) # Verify the options were OR'd into the context self.assertNotEqual( - context.options & expected_options, expected_options) + context.options & expected_options, expected_options + ) def test_hardened_ssl_context_options_fail_tls1_1(self): """Test SSL context options fail when different""" @@ -301,26 +307,27 @@ def test_hardened_ssl_context_options_fail_tls1_1(self): STRONG_TLS_CURVES, True, True, - False + False, ) # Verify options are set expected_options = ( - getattr(ssl, 'OP_NO_SSLv2', 0x1000000) | - getattr(ssl, 'OP_NO_SSLv3', 0x2000000) | - getattr(ssl, 'OP_NO_TLSv1', 0x4000000) | - getattr(ssl, 'OP_NO_TLSv1_1', 0x10000000) | - getattr(ssl, 'OP_NO_COMPRESSION', 0x20000) | - getattr(ssl, 'OP_CIPHER_SERVER_PREFERENCE', 0x400000) | - getattr(ssl, 'OP_SINGLE_ECDH_USE', 0x80000) | - getattr(ssl, 'OP_SINGLE_DH_USE', 0x100000) | - getattr(ssl, 'OP_NO_RENEGOTIATION', 0x40000000) | - getattr(ssl, 'OP_RENEGOTIATION', 0x40000000) + getattr(ssl, "OP_NO_SSLv2", 0x1000000) + | getattr(ssl, "OP_NO_SSLv3", 0x2000000) + | getattr(ssl, "OP_NO_TLSv1", 0x4000000) + | getattr(ssl, "OP_NO_TLSv1_1", 0x10000000) + | getattr(ssl, "OP_NO_COMPRESSION", 0x20000) + | getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0x400000) + | getattr(ssl, "OP_SINGLE_ECDH_USE", 0x80000) + | getattr(ssl, "OP_SINGLE_DH_USE", 0x100000) + | getattr(ssl, "OP_NO_RENEGOTIATION", 0x40000000) + | getattr(ssl, "OP_RENEGOTIATION", 0x40000000) ) # Verify the options were OR'd into the context self.assertNotEqual( - context.options & expected_options, expected_options) + context.options & expected_options, expected_options + ) def test_hardened_ssl_context_options_fail_tls1_2(self): """Test SSL context options fail when different""" @@ -336,26 +343,27 @@ def test_hardened_ssl_context_options_fail_tls1_2(self): STRONG_TLS_CURVES, True, False, - False + False, ) # Verify options are set expected_options = ( - getattr(ssl, 'OP_NO_SSLv2', 0x1000000) | - getattr(ssl, 'OP_NO_SSLv3', 0x2000000) | - getattr(ssl, 'OP_NO_TLSv1', 0x4000000) | - getattr(ssl, 'OP_NO_TLSv1_1', 0x10000000) | - getattr(ssl, 'OP_NO_COMPRESSION', 0x20000) | - 
getattr(ssl, 'OP_CIPHER_SERVER_PREFERENCE', 0x400000) | - getattr(ssl, 'OP_SINGLE_ECDH_USE', 0x80000) | - getattr(ssl, 'OP_SINGLE_DH_USE', 0x100000) | - getattr(ssl, 'OP_NO_RENEGOTIATION', 0x40000000) | - getattr(ssl, 'OP_RENEGOTIATION', 0x40000000) + getattr(ssl, "OP_NO_SSLv2", 0x1000000) + | getattr(ssl, "OP_NO_SSLv3", 0x2000000) + | getattr(ssl, "OP_NO_TLSv1", 0x4000000) + | getattr(ssl, "OP_NO_TLSv1_1", 0x10000000) + | getattr(ssl, "OP_NO_COMPRESSION", 0x20000) + | getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0x400000) + | getattr(ssl, "OP_SINGLE_ECDH_USE", 0x80000) + | getattr(ssl, "OP_SINGLE_DH_USE", 0x100000) + | getattr(ssl, "OP_NO_RENEGOTIATION", 0x40000000) + | getattr(ssl, "OP_RENEGOTIATION", 0x40000000) ) # Verify the options were OR'd into the context self.assertNotEqual( - context.options & expected_options, expected_options) + context.options & expected_options, expected_options + ) def test_hardened_ssl_context_ciphers(self): """Test SSL context ciphers are set correctly""" @@ -371,12 +379,12 @@ def test_hardened_ssl_context_ciphers(self): STRONG_TLS_CURVES, False, True, - False + False, ) # NOTE: this may be too platform specific expected_start = "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:" expected_end = ":DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384" - result = ':'.join([spec['name'] for spec in context.get_ciphers()]) + result = ":".join([spec["name"] for spec in context.get_ciphers()]) self.assertTrue(result.startswith(expected_start)) self.assertTrue(result.endswith(expected_end)) @@ -394,11 +402,11 @@ def test_hardened_ssl_context_legacy_ciphers(self): STRONG_TLS_CURVES, False, True, - False + False, ) # NOTE: this may be too platform specific expected_start = "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:" expected_end = ":CAMELLIA256-SHA256:CAMELLIA128-SHA256" - result = ':'.join([spec['name'] for spec in context.get_ciphers()]) + result = ":".join([spec["name"] for spec in context.get_ciphers()]) self.assertTrue(result.startswith(expected_start)) self.assertTrue(result.endswith(expected_end)) diff --git a/tests/test_mig_shared_transferfunctions.py b/tests/test_mig_shared_transferfunctions.py index fd8c7d5cc..53665259a 100644 --- a/tests/test_mig_shared_transferfunctions.py +++ b/tests/test_mig_shared_transferfunctions.py @@ -31,16 +31,27 @@ import sys import tempfile -from tests.support import TEST_OUTPUT_DIR, MigTestCase, FakeConfiguration, \ - cleanpath, temppath, testmain -from mig.shared.transferfunctions import get_transfers_path, \ - load_data_transfers, create_data_transfer, delete_data_transfer, \ - lock_data_transfers, unlock_data_transfers +from mig.shared.transferfunctions import ( + create_data_transfer, + delete_data_transfer, + get_transfers_path, + load_data_transfers, + lock_data_transfers, + unlock_data_transfers, +) +from tests.support import ( + TEST_OUTPUT_DIR, + FakeConfiguration, + MigTestCase, + cleanpath, + temppath, + testmain, +) DUMMY_USER = "dummy-user" DUMMY_ID = "dummy-id" -DUMMY_HOME_DIR = 'dummy_user_home' -DUMMY_SETTINGS_DIR = 'dummy_user_settings' +DUMMY_HOME_DIR = "dummy_user_home" +DUMMY_SETTINGS_DIR = "dummy_user_settings" def noop(*args, **kwargs): @@ -55,13 +66,15 @@ class MigSharedTransferfunctions(MigTestCase): def before_each(self): test_user_home = temppath(DUMMY_HOME_DIR, self, ensure_dir=True) test_user_settings = cleanpath( - DUMMY_SETTINGS_DIR, self, ensure_dir=True) + DUMMY_SETTINGS_DIR, self, ensure_dir=True + ) # make two requisite root folders for the dummy user 
os.mkdir(os.path.join(test_user_home, DUMMY_USER)) os.mkdir(os.path.join(test_user_settings, DUMMY_USER)) # now create a configuration - self.dummy_conf = FakeConfiguration(user_home=test_user_home, - user_settings=test_user_settings) + self.dummy_conf = FakeConfiguration( + user_home=test_user_home, user_settings=test_user_settings + ) def test_transfers_basic_locking_shared(self): dummy_conf = self.dummy_conf @@ -82,9 +95,11 @@ def test_transfers_basic_locking_ro_to_rw_exclusive(self): # Non-blocking exclusive locking of shared lock must fail ro_lock = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) rw_lock = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) self.assertTrue(ro_lock) self.assertFalse(rw_lock) @@ -99,9 +114,11 @@ def test_transfers_basic_locking_exclusive(self): rw_lock = lock_data_transfers(transfers_path, exclusive=True) # Non-blocking repeated shared or exclusive locking must fail ro_lock_again = lock_data_transfers( - transfers_path, exclusive=False, blocking=False) + transfers_path, exclusive=False, blocking=False + ) rw_lock_again = lock_data_transfers( - transfers_path, exclusive=True, blocking=False) + transfers_path, exclusive=True, blocking=False + ) self.assertTrue(rw_lock) self.assertFalse(ro_lock_again) @@ -112,18 +129,19 @@ def test_transfers_basic_locking_exclusive(self): def test_create_and_delete_transfer(self): dummy_conf = self.dummy_conf - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}) + success, out = create_data_transfer( + dummy_conf, DUMMY_USER, {"transfer_id": DUMMY_ID} + ) self.assertTrue(success and DUMMY_ID in out) - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and transfers.get(DUMMY_ID, None)) - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID) + success, out = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID) self.assertTrue(success and out == DUMMY_ID) - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and transfers.get(DUMMY_ID, None) is None) @@ -131,38 +149,48 @@ def test_transfers_shared_read_locking(self): dummy_conf = self.dummy_conf transfers_path = get_transfers_path(dummy_conf, DUMMY_USER) # Init a dummy transfer to read and delete later - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) # take a shared ro lock up front ro_lock = lock_data_transfers(transfers_path, exclusive=False) # cases: - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and DUMMY_ID in transfers) # Create with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Delete with repeated locking should fail - (success, out) = delete_data_transfer(dummy_conf, 
DUMMY_USER, DUMMY_ID, - do_lock=True, blocking=False) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=True, blocking=False + ) self.assertFalse(success) # Verify unchanged - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER) + success, transfers = load_data_transfers(dummy_conf, DUMMY_USER) self.assertTrue(success and DUMMY_ID in transfers) # Unlock all to leave critical section and allow clean up unlock_data_transfers(ro_lock) # Delete with locking should be fine again - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=True) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=True + ) self.assertTrue(success and out == DUMMY_ID) def test_transfers_exclusive_write_locking(self): @@ -174,40 +202,51 @@ def test_transfers_exclusive_write_locking(self): # cases: # Non-blocking load with repeated locking should fail - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER, - do_lock=True, blocking=False) + success, transfers = load_data_transfers( + dummy_conf, DUMMY_USER, do_lock=True, blocking=False + ) self.assertFalse(success) # Load without repeated locking should be fine - (success, transfers) = load_data_transfers(dummy_conf, DUMMY_USER, - do_lock=False) + success, transfers = load_data_transfers( + dummy_conf, DUMMY_USER, do_lock=False + ) self.assertTrue(success) # Non-blocking create with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = create_data_transfer( + dummy_conf, + DUMMY_USER, + {"transfer_id": DUMMY_ID}, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Create without repeated locking should be fine - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=False) + success, out = create_data_transfer( + dummy_conf, DUMMY_USER, {"transfer_id": DUMMY_ID}, do_lock=False + ) self.assertTrue(success) # Non-blocking delete with repeated locking should fail - (success, out) = create_data_transfer(dummy_conf, DUMMY_USER, - {'transfer_id': DUMMY_ID}, - do_lock=True, blocking=False) + success, out = delete_data_transfer( + dummy_conf, + DUMMY_USER, + DUMMY_ID, + do_lock=True, + blocking=False, + ) self.assertFalse(success) # Delete without repeated locking should be fine - (success, out) = delete_data_transfer(dummy_conf, DUMMY_USER, DUMMY_ID, - do_lock=False) + success, out = delete_data_transfer( + dummy_conf, DUMMY_USER, DUMMY_ID, do_lock=False + ) self.assertTrue(success) unlock_data_transfers(rw_lock) -if __name__ == '__main__': +if __name__ == "__main__": testmain(failfast=True) diff --git a/tests/test_mig_shared_url.py b/tests/test_mig_shared_url.py index 1f029775d..b81cf9799 100644 --- a/tests/test_mig_shared_url.py +++ b/tests/test_mig_shared_url.py @@ -32,8 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) -from tests.support import MigTestCase, FakeConfiguration, testmain from mig.shared.url import _get_site_urls, check_local_site_url +from tests.support import FakeConfiguration, MigTestCase, testmain def _generate_dynamic_site_urls(url_base_list): @@ -42,40 +42,58 @@ """ site_urls = [] for url_base in url_base_list: - site_urls += ['%s' % url_base, '%s/' % url_base, - '%s/wsgi-bin/home.py' % url_base, - '%s/wsgi-bin/logout.py' % url_base, -
'%s/wsgi-bin/logout.py?return_url=' % url_base, - '%s/wsgi-bin/logout.py?return_url=%s' % (url_base, - ENC_URL) - ] + site_urls += [ + "%s" % url_base, + "%s/" % url_base, + "%s/wsgi-bin/home.py" % url_base, + "%s/wsgi-bin/logout.py" % url_base, + "%s/wsgi-bin/logout.py?return_url=" % url_base, + "%s/wsgi-bin/logout.py?return_url=%s" % (url_base, ENC_URL), + ] return site_urls -DUMMY_CONF = FakeConfiguration(migserver_http_url='http://myfqdn.org', - migserver_https_url='https://myfqdn.org', - migserver_https_mig_cert_url='', - migserver_https_ext_cert_url='', - migserver_https_mig_oid_url='', - migserver_https_ext_oid_url='', - migserver_https_mig_oidc_url='', - migserver_https_ext_oidc_url='', - migserver_https_sid_url='', - migserver_public_url='', - migserver_public_alias_url='') -ENC_URL = 'https%3A%2F%2Fsomewhere.org%2Fsub%0A' +DUMMY_CONF = FakeConfiguration( + migserver_http_url="http://myfqdn.org", + migserver_https_url="https://myfqdn.org", + migserver_https_mig_cert_url="", + migserver_https_ext_cert_url="", + migserver_https_mig_oid_url="", + migserver_https_ext_oid_url="", + migserver_https_mig_oidc_url="", + migserver_https_ext_oidc_url="", + migserver_https_sid_url="", + migserver_public_url="", + migserver_public_alias_url="", +) +ENC_URL = "https%3A%2F%2Fsomewhere.org%2Fsub%0A" # Static site-anchored and dynamic full URLs for local site URL check success -LOCAL_SITE_URLS = ['', 'abc', 'abc.txt', '/', '/bla', '/bla#anchor', - '/bla/', '/bla/#anchor', '/bla/bla', '/bla/bla/bla', - '//bla//', './bla', './bla/', './bla/bla', - './bla/bla/bla', 'logout.py', 'logout.py?bla=', - '/cgi-sid/logout.py', '/cgi-sid/logout.py?bla=bla', - '/cgi-sid/logout.py?return_url=%s' % ENC_URL, - ] +LOCAL_SITE_URLS = [ + "", + "abc", + "abc.txt", + "/", + "/bla", + "/bla#anchor", + "/bla/", + "/bla/#anchor", + "/bla/bla", + "/bla/bla/bla", + "//bla//", + "./bla", + "./bla/", + "./bla/bla", + "./bla/bla/bla", + "logout.py", + "logout.py?bla=", + "/cgi-sid/logout.py", + "/cgi-sid/logout.py?bla=bla", + "/cgi-sid/logout.py?return_url=%s" % ENC_URL, +] LOCAL_BASE_URLS = _get_site_urls(DUMMY_CONF) LOCAL_SITE_URLS += _generate_dynamic_site_urls(LOCAL_SITE_URLS) # Dynamic full URLs for local site URL check failure -REMOTE_BASE_URLS = ['https://someevilsite.com', 'ftp://someevilsite.com'] +REMOTE_BASE_URLS = ["https://someevilsite.com", "ftp://someevilsite.com"] REMOTE_SITE_URLS = _generate_dynamic_site_urls(REMOTE_BASE_URLS) @@ -85,15 +103,19 @@ class BasicUrl(MigTestCase): def test_valid_local_site_urls(self): """Check known valid static and dynamic URLs""" for url in LOCAL_SITE_URLS: - self.assertTrue(check_local_site_url(DUMMY_CONF, url), - "Local site url should succeed for %s" % url) + self.assertTrue( + check_local_site_url(DUMMY_CONF, url), + "Local site url should succeed for %s" % url, + ) def test_invalid_local_site_urls(self): """Check known invalid URLs""" for url in REMOTE_SITE_URLS: - self.assertFalse(check_local_site_url(DUMMY_CONF, url), - "Local site url should fail for %s" % url) + self.assertFalse( + check_local_site_url(DUMMY_CONF, url), + "Local site url should fail for %s" % url, + ) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_userdb.py b/tests/test_mig_shared_userdb.py index 2efc524fc..083f14092 100644 --- a/tests/test_mig_shared_userdb.py +++ b/tests/test_mig_shared_userdb.py @@ -34,17 +34,26 @@ # Imports required for the unit test wrapping from mig.shared.base import distinguished_name_to_user from mig.shared.fileio import 
delete_file -from mig.shared.serial import loads, dumps +from mig.shared.serial import dumps, loads + # Imports of the code under test -from mig.shared.userdb import default_db_path, load_user_db, load_user_dict, \ - lock_user_db, save_user_db, save_user_dict, unlock_user_db, \ - update_user_dict +from mig.shared.userdb import ( + default_db_path, + load_user_db, + load_user_dict, + lock_user_db, + save_user_db, + save_user_dict, + unlock_user_db, + update_user_dict, +) + # Imports required for the unit tests themselves from tests.support import MigTestCase, ensure_dirs_exist, testmain -TEST_USER_ID = '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com' -THIS_USER_ID = '/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=This User/emailAddress=this.user@here.org' -OTHER_USER_ID = '/C=DK/ST=NA/L=NA/O=Other Org/OU=NA/CN=Other User/emailAddress=other.user@there.org' +TEST_USER_ID = "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test User/emailAddress=test@example.com" +THIS_USER_ID = "/C=DK/ST=NA/L=NA/O=Local Org/OU=NA/CN=This User/emailAddress=this.user@here.org" +OTHER_USER_ID = "/C=DK/ST=NA/L=NA/O=Other Org/OU=NA/CN=Other User/emailAddress=other.user@there.org" class TestMigSharedUserDB(MigTestCase): @@ -52,7 +61,7 @@ class TestMigSharedUserDB(MigTestCase): def _provide_configuration(self): """Get test configuration""" - return 'testconfig' + return "testconfig" # Helper methods def _generate_sample_db(self, content=None): @@ -60,7 +69,7 @@ def _generate_sample_db(self, content=None): if content is None: sample_db = { TEST_USER_ID: distinguished_name_to_user(TEST_USER_ID), - THIS_USER_ID: distinguished_name_to_user(THIS_USER_ID) + THIS_USER_ID: distinguished_name_to_user(THIS_USER_ID), } else: sample_db = content @@ -78,10 +87,12 @@ def before_each(self): """Set up test configuration and reset user DB paths""" ensure_dirs_exist(self.configuration.user_db_home) ensure_dirs_exist(self.configuration.mig_server_home) - self.user_db_path = os.path.join(self.configuration.user_db_home, - "MiG-users.db") - self.legacy_db_path = os.path.join(self.configuration.mig_server_home, - "MiG-users.db") + self.user_db_path = os.path.join( + self.configuration.user_db_home, "MiG-users.db" + ) + self.legacy_db_path = os.path.join( + self.configuration.mig_server_home, "MiG-users.db" + ) # Clear any existing test DBs if os.path.exists(self.user_db_path): @@ -95,15 +106,15 @@ def before_each(self): def test_default_db_path(self): """Test default_db_path returns correct path structure""" - expected = os.path.join(self.configuration.user_db_home, - "MiG-users.db") + expected = os.path.join(self.configuration.user_db_home, "MiG-users.db") result = default_db_path(self.configuration) self.assertEqual(result, expected) # Test legacy path fallback - self.configuration.user_db_home = '/no-such-dir' - expected_legacy = os.path.join(self.configuration.mig_server_home, - "MiG-users.db") + self.configuration.user_db_home = "/no-such-dir" + expected_legacy = os.path.join( + self.configuration.mig_server_home, "MiG-users.db" + ) result = default_db_path(self.configuration) self.assertEqual(result, expected_legacy) @@ -149,8 +160,7 @@ def test_load_user_db_direct(self): def test_load_user_db_missing(self): """Test loading missing user database""" - db_path = os.path.join( - self.configuration.user_db_home, "no-such-db.db") + db_path = os.path.join(self.configuration.user_db_home, "no-such-db.db") try: loaded = load_user_db(db_path) except Exception: @@ -190,8 +200,9 @@ def test_load_user_dict_missing(self): """Test 
loading non-existent user from DB""" self._create_sample_db() try: - loaded = load_user_dict(self.logger, "no-such-user", - self.user_db_path) + loaded = load_user_dict( + self.logger, "no-such-user", self.user_db_path + ) except Exception: loaded = None self.assertIsNone(loaded) @@ -200,8 +211,9 @@ def test_load_user_dict_existing(self): """Test loading existing user from DB""" sample_db = self._create_sample_db() try: - test_user_data = load_user_dict(self.logger, TEST_USER_ID, - self.user_db_path) + test_user_data = load_user_dict( + self.logger, TEST_USER_ID, self.user_db_path + ) except Exception: test_user_data = None self.assertEqual(test_user_data, sample_db[TEST_USER_ID]) @@ -209,8 +221,9 @@ def test_load_user_dict_existing(self): def test_save_user_dict_new_user(self): """Test saving new user to database""" other_user = distinguished_name_to_user(OTHER_USER_ID) - save_status = save_user_dict(self.logger, OTHER_USER_ID, - other_user, self.user_db_path) + save_status = save_user_dict( + self.logger, OTHER_USER_ID, other_user, self.user_db_path + ) self.assertTrue(save_status) with open(self.user_db_path, "rb") as fh: @@ -223,8 +236,9 @@ def test_save_user_dict_update(self): sample_db = self._create_sample_db() changed = distinguished_name_to_user(THIS_USER_ID) changed.update({"Organization": "UPDATED", "new_field": "ADDED"}) - save_status = save_user_dict(self.logger, THIS_USER_ID, - changed, self.user_db_path) + save_status = save_user_dict( + self.logger, THIS_USER_ID, changed, self.user_db_path + ) self.assertTrue(save_status) with open(self.user_db_path, "rb") as fh: @@ -235,9 +249,12 @@ def test_save_user_dict_update(self): def test_update_user_dict(self): """Test update_user_dict with partial changes""" sample_db = self._create_sample_db() - updated = update_user_dict(self.logger, THIS_USER_ID, - {"Organization": "CHANGED"}, - self.user_db_path) + updated = update_user_dict( + self.logger, + THIS_USER_ID, + {"Organization": "CHANGED"}, + self.user_db_path, + ) self.assertEqual(updated["Organization"], "CHANGED") with open(self.user_db_path, "rb") as fh: @@ -249,8 +266,12 @@ def test_update_user_dict_requirements(self): """Test update_user_dict with invalid user ID""" self.logger.forgive_errors() try: - result = update_user_dict(self.logger, "no-such-user", - {"field": "test"}, self.user_db_path) + result = update_user_dict( + self.logger, + "no-such-user", + {"field": "test"}, + self.user_db_path, + ) except Exception: result = None self.assertIsNone(result) @@ -274,6 +295,7 @@ def delayed_load(): return loaded import threading + delayed_thread = threading.Thread(target=delayed_load) delayed_thread.start() time.sleep(0.2) @@ -308,7 +330,8 @@ def test_load_user_db_pickle_abi(self): def test_lock_user_db_invalid_path(self): """Test locking on non-existent database path""" invalid_path = os.path.join( - self.configuration.user_db_home, "missing", "MiG-users.db") + self.configuration.user_db_home, "missing", "MiG-users.db" + ) flock = lock_user_db(invalid_path) self.assertIsNone(flock) @@ -322,7 +345,7 @@ def test_unlock_user_db_invalid(self): def test_load_user_db_corrupted(self): """Test loading corrupted user database""" - with open(self.user_db_path, 'w') as fh: + with open(self.user_db_path, "w") as fh: fh.write("invalid pickle content") with self.assertRaises(Exception): load_user_db(self.user_db_path) @@ -354,8 +377,9 @@ def test_save_user_dict_invalid_id(self): """Test saving user with invalid characters in ID""" invalid_id = "../../invalid.user" user_dict = 
distinguished_name_to_user(TEST_USER_ID) - save_status = save_user_dict(self.logger, invalid_id, - user_dict, self.user_db_path) + save_status = save_user_dict( + self.logger, invalid_id, user_dict, self.user_db_path + ) self.assertFalse(save_status) def test_update_user_dict_empty_changes(self): @@ -364,8 +388,9 @@ def test_update_user_dict_empty_changes(self): self.logger.forgive_errors() sample_db = self._create_sample_db() original = sample_db[THIS_USER_ID].copy() - updated = update_user_dict(self.logger, THIS_USER_ID, {}, - self.user_db_path) + updated = update_user_dict( + self.logger, THIS_USER_ID, {}, self.user_db_path + ) self.assertEqual(updated, original) # TODO: adjust API to allow enabling the next test @@ -395,5 +420,5 @@ def test_load_user_db_allows_concurrent_read_access(self): unlock_user_db(flock) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_shared_userio.py b/tests/test_mig_shared_userio.py index 51310d8f3..932770a9d 100644 --- a/tests/test_mig_shared_userio.py +++ b/tests/test_mig_shared_userio.py @@ -523,7 +523,7 @@ class TestMigSharedUserIO__legacy_main(MigTestCase): """Legacy tests for corresponding module self-checks""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" # TODO: migrate all legacy self-check functionality into the above? def test_existing_main(self): @@ -538,14 +538,21 @@ def raise_on_error_exit(exit_code): else: identifying_message = "unknown" raise AssertionError( - 'legacy test failure: %s' % (identifying_message,)) + "legacy test failure: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): """Helper to show last print on error""" raise_on_error_exit.last_print = value - legacy_main(self.configuration, print=record_last_print, _exit=raise_on_error_exit, _argv=[]) + legacy_main( + self.configuration, + print=record_last_print, + _exit=raise_on_error_exit, + _argv=[], + ) if __name__ == "__main__": diff --git a/tests/test_mig_shared_vgrid.py b/tests/test_mig_shared_vgrid.py index 91afcfa07..613d3d812 100644 --- a/tests/test_mig_shared_vgrid.py +++ b/tests/test_mig_shared_vgrid.py @@ -1419,7 +1419,7 @@ class TestMigSharedVgrid__legacy_main(MigTestCase): """Unit tests for legacy vgrid self-checks""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_existing_main(self): """Run the legacy self-tests directly in module""" @@ -1433,14 +1433,20 @@ def raise_on_error_exit(exit_code): else: identifying_message = "unknown" raise AssertionError( - 'legacy test failure: %s' % (identifying_message,)) + "legacy test failure: %s" % (identifying_message,) + ) + raise_on_error_exit.last_print = None def record_last_print(value): """Keep track of printed output""" raise_on_error_exit.last_print = value - legacy_main(self.configuration, print=record_last_print, _exit=raise_on_error_exit) + legacy_main( + self.configuration, + print=record_last_print, + _exit=raise_on_error_exit, + ) if __name__ == "__main__": diff --git a/tests/test_mig_shared_vgridaccess.py b/tests/test_mig_shared_vgridaccess.py index 1dbe2a086..36dbb1704 100644 --- a/tests/test_mig_shared_vgridaccess.py +++ b/tests/test_mig_shared_vgridaccess.py @@ -35,92 +35,153 @@ import mig.shared.vgridaccess as vgridaccess from mig.shared.fileio import pickle, read_file from mig.shared.vgrid import vgrid_list, vgrid_set_entities, vgrid_settings -from mig.shared.vgridaccess import CONF, MEMBERS, OWNERS, RESOURCES, SETTINGS, \ - USERID, USERS, VGRIDS, 
check_resources_modified, check_vgrid_access, \ - check_vgrids_modified, fill_placeholder_cache, force_update_resource_map, \ - force_update_user_map, force_update_vgrid_map, get_re_provider_map, \ - get_resource_map, get_user_map, get_vgrid_map, get_vgrid_map_vgrids, \ - is_vgrid_parent_placeholder, load_resource_map, load_user_map, \ - load_vgrid_map, mark_vgrid_modified, refresh_resource_map, \ - refresh_user_map, refresh_vgrid_map, res_vgrid_access, \ - reset_resources_modified, reset_vgrids_modified, resources_using_re, \ - unmap_inheritance, unmap_resource, unmap_vgrid, user_allowed_res_confs, \ - user_allowed_res_exes, user_allowed_res_stores, user_allowed_res_units, \ - user_allowed_user_confs, user_owned_res_exes, user_owned_res_stores, \ - user_vgrid_access, user_visible_res_confs, user_visible_res_exes, \ - user_visible_res_stores, user_visible_user_confs, vgrid_inherit_map +from mig.shared.vgridaccess import ( + CONF, + MEMBERS, + OWNERS, + RESOURCES, + SETTINGS, + USERID, + USERS, + VGRIDS, + check_resources_modified, + check_vgrid_access, + check_vgrids_modified, + fill_placeholder_cache, + force_update_resource_map, + force_update_user_map, + force_update_vgrid_map, + get_re_provider_map, + get_resource_map, + get_user_map, + get_vgrid_map, + get_vgrid_map_vgrids, + is_vgrid_parent_placeholder, + load_resource_map, + load_user_map, + load_vgrid_map, + mark_vgrid_modified, + refresh_resource_map, + refresh_user_map, + refresh_vgrid_map, + res_vgrid_access, + reset_resources_modified, + reset_vgrids_modified, + resources_using_re, + unmap_inheritance, + unmap_resource, + unmap_vgrid, + user_allowed_res_confs, + user_allowed_res_exes, + user_allowed_res_stores, + user_allowed_res_units, + user_allowed_user_confs, + user_owned_res_exes, + user_owned_res_stores, + user_vgrid_access, + user_visible_res_confs, + user_visible_res_exes, + user_visible_res_stores, + user_visible_user_confs, + vgrid_inherit_map, +) from tests.support import MigTestCase, ensure_dirs_exist, testmain -from tests.support.usersupp import UserAssertMixin, TEST_USER_DN +from tests.support.usersupp import TEST_USER_DN, UserAssertMixin class TestMigSharedVgridAccess(MigTestCase, UserAssertMixin): """Unit tests for vgridaccess related helper functions""" - TEST_OWNER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/'\ - 'emailAddress=owner@example.org' - TEST_MEMBER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/'\ - 'emailAddress=member@example.org' - TEST_OUTSIDER_DN = \ - '/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/'\ - 'emailAddress=outsider@example.com' - TEST_RESOURCE_ID = 'test.example.org.0' - TEST_VGRID_NAME = 'testvgrid' - - TEST_OWNER_UUID = 'ff326a2b984828d9b32077c9b0b35a05' - TEST_MEMBER_UUID = 'ea9aedcbe69db279ca3676f83de94669' - TEST_RESOURCE_ALIAS = '0835f310d6422c36e33eeb7d0d3e9cf5' + TEST_OWNER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Owner/" + "emailAddress=owner@example.org" + ) + TEST_MEMBER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Member/" + "emailAddress=member@example.org" + ) + TEST_OUTSIDER_DN = ( + "/C=DK/ST=NA/L=NA/O=Test Org/OU=NA/CN=Test Outsider/" + "emailAddress=outsider@example.com" + ) + TEST_RESOURCE_ID = "test.example.org.0" + TEST_VGRID_NAME = "testvgrid" + + TEST_OWNER_UUID = "ff326a2b984828d9b32077c9b0b35a05" + TEST_MEMBER_UUID = "ea9aedcbe69db279ca3676f83de94669" + TEST_RESOURCE_ALIAS = "0835f310d6422c36e33eeb7d0d3e9cf5" # Default vgrid is initially set up without settings when force loaded - MINIMAL_VGRIDS = {'Generic': 
{OWNERS: [], MEMBERS: [], RESOURCES: [], - SETTINGS: []}} + MINIMAL_VGRIDS = { + "Generic": {OWNERS: [], MEMBERS: [], RESOURCES: [], SETTINGS: []} + } def _provide_configuration(self): """Prepare isolated test config""" - return 'testconfig' - - def _create_vgrid(self, vgrid_name, *, owners=None, members=None, - resources=None, settings=None, triggers=None): + return "testconfig" + + def _create_vgrid( + self, + vgrid_name, + *, + owners=None, + members=None, + resources=None, + settings=None, + triggers=None + ): """Helper to create valid skeleton vgrid for testing""" vgrid_path = os.path.join(self.configuration.vgrid_home, vgrid_name) ensure_dirs_exist(vgrid_path) # Save vgrid owners, members, resources, settings and triggers if owners is None: owners = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'owners', owners, allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, vgrid_name, "owners", owners, allow_empty=True + ) self.assertEqual(success_and_msg, (True, "")) if members is None: members = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'members', members, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, vgrid_name, "members", members, allow_empty=True + ) self.assertEqual(success_and_msg, (True, "")) if resources is None: resources = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'resources', resources, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "resources", + resources, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if settings is None: - settings = [('vgrid_name', vgrid_name)] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'settings', settings, - allow_empty=True) + settings = [("vgrid_name", vgrid_name)] + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "settings", + settings, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) if triggers is None: triggers = [] - success_and_msg = vgrid_set_entities(self.configuration, vgrid_name, - 'triggers', triggers, - allow_empty=True) + success_and_msg = vgrid_set_entities( + self.configuration, + vgrid_name, + "triggers", + triggers, + allow_empty=True, + ) self.assertEqual(success_and_msg, (True, "")) def _create_resource(self, res_name, owners, config=None): """Helper to create valid skeleton resource for testing""" res_path = os.path.join(self.configuration.resource_home, res_name) - res_owners_path = os.path.join(res_path, 'owners') - res_config_path = os.path.join(res_path, 'config') + res_owners_path = os.path.join(res_path, "owners") + res_config_path = os.path.join(res_path, "config") # Add resource skeleton with owners ensure_dirs_exist(res_path) if owners is None: @@ -129,10 +190,11 @@ def _create_resource(self, res_name, owners, config=None): self.assertTrue(saved) if config is None: # Make sure conf has one valid field - config = {'HOSTURL': res_name, - 'EXECONFIG': [{'name': 'exe', 'vgrid': ['Generic']}], - 'STORECONFIG': [{'name': 'exe', 'vgrid': ['Generic']}] - } + config = { + "HOSTURL": res_name, + "EXECONFIG": [{"name": "exe", "vgrid": ["Generic"]}], + "STORECONFIG": [{"name": "exe", "vgrid": ["Generic"]}], + } saved = pickle(config, res_config_path, self.logger) self.assertTrue(saved) @@ -234,8 +296,10 @@ def test_force_update_vgrid_map(self): updated_vgrid_map = force_update_vgrid_map(self.configuration) 
self._verify_vgrid_map_integrity(updated_vgrid_map) self.assertTrue(updated_vgrid_map) - self.assertNotEqual(len(vgrid_map_before.get(VGRIDS, {})), - len(updated_vgrid_map.get(VGRIDS, {}))) + self.assertNotEqual( + len(vgrid_map_before.get(VGRIDS, {})), + len(updated_vgrid_map.get(VGRIDS, {})), + ) self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) def test_refresh_user_map(self): @@ -339,7 +403,7 @@ def test_get_vgrid_map_vgrids(self): vgrid_list = get_vgrid_map_vgrids(self.configuration) self.assertTrue(isinstance(vgrid_list, list)) - self.assertEqual(['Generic'], vgrid_list) + self.assertEqual(["Generic"], vgrid_list) def test_user_owned_res_exes(self): """Test user_owned_res_exes returns owned execution nodes""" @@ -367,7 +431,8 @@ def test_user_allowed_res_units(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) allowed = user_allowed_res_units( - self.configuration, self.TEST_OWNER_DN, "exe") + self.configuration, self.TEST_OWNER_DN, "exe" + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) @@ -394,7 +459,8 @@ def test_user_allowed_res_stores(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) allowed = user_allowed_res_stores( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) @@ -419,23 +485,29 @@ def test_user_visible_res_stores(self): self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) visible = user_visible_res_stores( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(visible, dict)) self.assertIn(self.TEST_RESOURCE_ALIAS, visible) def test_user_allowed_user_confs(self): """Test user_allowed_user_confs returns allowed user confs""" - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) - - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) + + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_user_map(self.configuration) allowed = user_allowed_user_confs( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertTrue(isinstance(allowed, dict)) self.assertIn(self.TEST_OWNER_UUID, allowed) self.assertIn(self.TEST_MEMBER_UUID, allowed) @@ -443,32 +515,33 @@ def test_user_allowed_user_confs(self): def test_fill_placeholder_cache(self): """Test fill_placeholder_cache populates cache""" cache = {} - fill_placeholder_cache(self.configuration, cache, [ - self.TEST_VGRID_NAME]) + fill_placeholder_cache( + self.configuration, cache, [self.TEST_VGRID_NAME] + ) self.assertIn(self.TEST_VGRID_NAME, cache) def test_is_vgrid_parent_placeholder(self): """Test is_vgrid_parent_placeholder detection""" - test_path = os.path.join(self.configuration.user_home, 'testvgrid') - result = is_vgrid_parent_placeholder(self.configuration, test_path, - test_path) + test_path = 
os.path.join(self.configuration.user_home, "testvgrid") + result = is_vgrid_parent_placeholder( + self.configuration, test_path, test_path + ) self.assertIsNone(result) def test_resources_using_re_notfound(self): """Test RE with no assigned resources returns empty list""" # Nonexistent RE should have no resources - res_list = resources_using_re(self.configuration, 'NoSuchRE') + res_list = resources_using_re(self.configuration, "NoSuchRE") self.assertEqual(res_list, []) def test_vgrid_inherit_map_single(self): """Test inheritance mapping with single vgrid""" - test_settings = [('vgrid_name', self.TEST_VGRID_NAME), - ('hidden', True)] + test_settings = [("vgrid_name", self.TEST_VGRID_NAME), ("hidden", True)] test_map = { VGRIDS: { self.TEST_VGRID_NAME: { SETTINGS: test_settings, - OWNERS: [self.TEST_OWNER_DN] + OWNERS: [self.TEST_OWNER_DN], } } } @@ -477,13 +550,13 @@ def test_vgrid_inherit_map_single(self): self.assertIn(self.TEST_VGRID_NAME, vgrid_data) settings_dict = dict(vgrid_data[self.TEST_VGRID_NAME][SETTINGS]) self.assertIs(type(settings_dict), dict) - self.assertEqual(settings_dict.get('hidden'), True) + self.assertEqual(settings_dict.get("hidden"), True) # TODO: move these two modified tests to a test_mig_shared_modified.py def test_check_vgrids_modified_initial(self): """Verify initial modified vgrids list marks ALL and empty on reset""" modified, stamp = check_vgrids_modified(self.configuration) - self.assertEqual(modified, ['ALL']) + self.assertEqual(modified, ["ALL"]) reset_vgrids_modified(self.configuration) modified, stamp = check_vgrids_modified(self.configuration) self.assertEqual(modified, []) @@ -506,10 +579,9 @@ def test_user_vgrid_access(self): self._provision_test_user(self, TEST_USER_DN) # Start with global access to default vgrid - allowed_vgrids = user_vgrid_access(self.configuration, - TEST_USER_DN) + allowed_vgrids = user_vgrid_access(self.configuration, TEST_USER_DN) - self.assertIn('Generic', allowed_vgrids) + self.assertIn("Generic", allowed_vgrids) self.assertTrue(len(allowed_vgrids), 1) # Create private vgrid self._create_vgrid(self.TEST_VGRID_NAME, owners=[TEST_USER_DN]) @@ -517,20 +589,21 @@ def test_user_vgrid_access(self): initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) - allowed_vgrids = user_vgrid_access(self.configuration, - TEST_USER_DN) + allowed_vgrids = user_vgrid_access(self.configuration, TEST_USER_DN) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) def test_res_vgrid_access(self): """Minimal test for resource vgrid participation""" # Only Generic access initially allowed_vgrids = res_vgrid_access( - self.configuration, self.TEST_RESOURCE_ID) - self.assertEqual(allowed_vgrids, ['Generic']) + self.configuration, self.TEST_RESOURCE_ID + ) + self.assertEqual(allowed_vgrids, ["Generic"]) # Add to vgrid self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, resources=[ - self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, resources=[self.TEST_RESOURCE_ID] + ) # Refresh maps to reflect new content initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) @@ -550,33 +623,40 @@ def test_vgrid_map_refresh(self): self._verify_vgrid_map_integrity(updated_vgrid_map) vgrids = updated_vgrid_map.get(VGRIDS, {}) self.assertIn(self.TEST_VGRID_NAME, vgrids) - 
self.assertEqual(vgrids[self.TEST_VGRID_NAME] - [OWNERS], [self.TEST_OWNER_DN]) + self.assertEqual( + vgrids[self.TEST_VGRID_NAME][OWNERS], [self.TEST_OWNER_DN] + ) def test_user_map_access(self): """Test user permissions through cached access maps""" # Add user as member - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Verify member access - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) def test_resource_map_update(self): """Verify resource visibility in cache""" # Check cached resource map does not yet contain entry - res_map_before, _ = load_resource_map(self.configuration, - caching=True) + res_map_before, _ = load_resource_map(self.configuration, caching=True) self.assertEqual(res_map_before, {}) # Add vgrid with assigned resource self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + resources=[self.TEST_RESOURCE_ID], + ) updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) @@ -594,11 +674,13 @@ def test_resource_map_update(self): def test_settings_inheritance(self): """Test inherited settings propagation through cached maps""" # Create top and sub vgrids with 'hidden' setting on top vgrid - top_settings = [('vgrid_name', self.TEST_VGRID_NAME), - ('hidden', True)] - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - settings=top_settings) - sub_vgrid = os.path.join(self.TEST_VGRID_NAME, 'subvgrid') + top_settings = [("vgrid_name", self.TEST_VGRID_NAME), ("hidden", True)] + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + settings=top_settings, + ) + sub_vgrid = os.path.join(self.TEST_VGRID_NAME, "subvgrid") self._create_vgrid(sub_vgrid) # Force refresh of cached map @@ -618,7 +700,7 @@ def test_settings_inheritance(self): self.assertTrue(top_settings_dict) # Verify hidden setting in cache - self.assertEqual(top_settings_dict.get('hidden'), True) + self.assertEqual(top_settings_dict.get("hidden"), True) # Retrieve sub vgrid settings from cached map sub_vgrid_data = vgrid_data.get(sub_vgrid, {}) @@ -626,10 +708,9 @@ def test_settings_inheritance(self): sub_settings_dict = dict(sub_vgrid_data.get(SETTINGS, [])) # Verify hidden setting unset without inheritance - self.assertFalse(sub_settings_dict.get('hidden')) + self.assertFalse(sub_settings_dict.get("hidden")) - inherited_map = vgrid_inherit_map( - self.configuration, updated_vgrid_map) + inherited_map = vgrid_inherit_map(self.configuration, updated_vgrid_map) vgrid_data = inherited_map.get(VGRIDS, {}) self.assertTrue(vgrid_data) @@ -639,12 +720,12 @@ def test_settings_inheritance(self): sub_settings_dict = dict(sub_vgrid_data.get(SETTINGS, [])) # Verify hidden setting inheritance - self.assertEqual(sub_settings_dict.get('hidden'), True) + 
self.assertEqual(sub_settings_dict.get("hidden"), True) def test_unmap_inheritance(self): """Test unmap_inheritance clears inherited mappings""" self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN]) - sub_vgrid = os.path.join(self.TEST_VGRID_NAME, 'subvgrid') + sub_vgrid = os.path.join(self.TEST_VGRID_NAME, "subvgrid") self._create_vgrid(sub_vgrid) # Force refresh of cached map @@ -653,8 +734,9 @@ def test_unmap_inheritance(self): self.assertIn(self.TEST_VGRID_NAME, updated_vgrid_map.get(VGRIDS, {})) # Unmap and verify mark modified - unmap_inheritance(self.configuration, self.TEST_VGRID_NAME, - self.TEST_OWNER_DN) + unmap_inheritance( + self.configuration, self.TEST_VGRID_NAME, self.TEST_OWNER_DN + ) modified, stamp = check_vgrids_modified(self.configuration) self.assertEqual(modified, [self.TEST_VGRID_NAME, sub_vgrid]) @@ -662,8 +744,9 @@ def test_unmap_inheritance(self): def test_user_map_fields(self): """Verify user map includes complete profile/settings data""" # First add a couple of test users - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) # Force fresh user map initial_vgrid_map = force_update_vgrid_map(self.configuration) @@ -680,8 +763,11 @@ def test_resource_revoked_access(self): """Verify resource removal propagates through cached maps""" # First add resource and vgrid self._create_resource(self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + resources=[self.TEST_RESOURCE_ID], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) @@ -697,9 +783,14 @@ def test_resource_revoked_access(self): self.assertIn(self.TEST_RESOURCE_ID, initial_map) # Remove resource assignment from vgrid - success_and_msg = vgrid_set_entities(self.configuration, self.TEST_VGRID_NAME, - 'resources', [], allow_empty=True) - self.assertEqual(success_and_msg, (True, '')) + success_and_msg = vgrid_set_entities( + self.configuration, + self.TEST_VGRID_NAME, + "resources", + [], + allow_empty=True, + ) + self.assertEqual(success_and_msg, (True, "")) updated_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(updated_vgrid_map) @@ -717,9 +808,9 @@ def test_resource_revoked_access(self): def test_non_recursive_inheritance(self): """Verify non-recursive map excludes nested vgrids""" # Create parent+child vgrids - parent_vgrid = 'parent' + parent_vgrid = "parent" self._create_vgrid(parent_vgrid, owners=[self.TEST_OWNER_DN]) - child_vgrid = os.path.join(parent_vgrid, 'child') + child_vgrid = os.path.join(parent_vgrid, "child") self._create_vgrid(child_vgrid, members=[self.TEST_MEMBER_DN]) # Force update to avoid auto caching and get non-recursive map @@ -733,21 +824,25 @@ def test_non_recursive_inheritance(self): # Child should still appear when non-recursive but just not inherit self.assertIn(child_vgrid, vgrid_map.get(VGRIDS, {})) # Check owners and members to verify they aren't inherited - self.assertEqual(vgrid_map[VGRIDS][parent_vgrid][OWNERS], - [self.TEST_OWNER_DN]) + self.assertEqual( + vgrid_map[VGRIDS][parent_vgrid][OWNERS], [self.TEST_OWNER_DN] + ) self.assertEqual(len(vgrid_map[VGRIDS][parent_vgrid][MEMBERS]), 0) self.assertEqual(len(vgrid_map[VGRIDS][child_vgrid][OWNERS]), 0) - 
self.assertEqual(vgrid_map[VGRIDS][child_vgrid][MEMBERS], - [self.TEST_MEMBER_DN]) + self.assertEqual( + vgrid_map[VGRIDS][child_vgrid][MEMBERS], [self.TEST_MEMBER_DN] + ) def test_hidden_setting_propagation(self): """Verify hidden=True on a child does not propagate to parent settings""" - parent_vgrid = 'parent' + parent_vgrid = "parent" self._create_vgrid(parent_vgrid, owners=[self.TEST_OWNER_DN]) - child_vgrid = os.path.join(parent_vgrid, 'child') - self._create_vgrid(child_vgrid, owners=[self.TEST_OWNER_DN], - settings=[('vgrid_name', child_vgrid), - ('hidden', True)]) + child_vgrid = os.path.join(parent_vgrid, "child") + self._create_vgrid( + child_vgrid, + owners=[self.TEST_OWNER_DN], + settings=[("vgrid_name", child_vgrid), ("hidden", True)], + ) # Verify parent remains visible in cache updated_vgrid_map = force_update_vgrid_map(self.configuration) @@ -756,64 +851,79 @@ def test_hidden_setting_propagation(self): self.assertIn(child_vgrid, updated_vgrid_map.get(VGRIDS, {})) parent_data = updated_vgrid_map.get(VGRIDS, {}).get(parent_vgrid, {}) parent_settings = dict(parent_data.get(SETTINGS, [])) - self.assertNotEqual(parent_settings.get('hidden'), True) + self.assertNotEqual(parent_settings.get("hidden"), True) def test_default_vgrid_access(self): """Verify special access rules for default vgrid""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Even non-member should have access to default vgrid - participant = check_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN, - 'Generic') + participant = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, "Generic" + ) self.assertFalse(participant) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN) - self.assertIn('Generic', allowed_vgrids) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN + ) + self.assertIn("Generic", allowed_vgrids) # Invalid vgrid should not allow any participation or access - participant = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - 'invalid-vgrid-name') + participant = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, "invalid-vgrid-name" + ) self.assertFalse(participant) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_MEMBER_DN) - self.assertNotIn('invalid-vgrid-name', allowed_vgrids) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_MEMBER_DN + ) + self.assertNotIn("invalid-vgrid-name", allowed_vgrids) def test_general_vgrid_access(self): """Verify general access rules for vgrids""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[self.TEST_OWNER_DN], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[self.TEST_OWNER_DN], + members=[self.TEST_MEMBER_DN], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Test vgrid must allow owner and members access - allowed = check_vgrid_access(self.configuration, self.TEST_OWNER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OWNER_DN, self.TEST_VGRID_NAME + ) 
self.assertTrue(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OWNER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OWNER_DN + ) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, self.TEST_VGRID_NAME + ) self.assertTrue(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_MEMBER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_MEMBER_DN + ) self.assertIn(self.TEST_VGRID_NAME, allowed_vgrids) # Test vgrid must reject outsider access - allowed = check_vgrid_access(self.configuration, self.TEST_OUTSIDER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, self.TEST_VGRID_NAME + ) self.assertFalse(allowed) - allowed_vgrids = user_vgrid_access(self.configuration, - self.TEST_OUTSIDER_DN) + allowed_vgrids = user_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN + ) self.assertNotIn(self.TEST_VGRID_NAME, allowed_vgrids) def test_user_allowed_res_confs(self): @@ -821,44 +931,49 @@ def test_user_allowed_res_confs(self): # Create test user and add test resource to vgrid self._provision_test_user(self, TEST_USER_DN) self._create_resource(self.TEST_RESOURCE_ID, [TEST_USER_DN]) - self._create_vgrid(self.TEST_VGRID_NAME, owners=[TEST_USER_DN], - resources=[self.TEST_RESOURCE_ID]) + self._create_vgrid( + self.TEST_VGRID_NAME, + owners=[TEST_USER_DN], + resources=[self.TEST_RESOURCE_ID], + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) force_update_resource_map(self.configuration) # Owner should be allowed access - allowed = user_allowed_res_confs(self.configuration, - TEST_USER_DN) + allowed = user_allowed_res_confs(self.configuration, TEST_USER_DN) self.assertIn(self.TEST_RESOURCE_ALIAS, allowed) def test_user_visible_res_confs(self): """Minimal test for user_visible_res_confs""" # Owner should see owned resources even without vgrid access - self._create_resource(self.TEST_RESOURCE_ID, - owners=[self.TEST_OWNER_DN]) + self._create_resource( + self.TEST_RESOURCE_ID, owners=[self.TEST_OWNER_DN] + ) force_update_resource_map(self.configuration) - visible = user_visible_res_confs( - self.configuration, self.TEST_OWNER_DN) + visible = user_visible_res_confs(self.configuration, self.TEST_OWNER_DN) self.assertIn(self.TEST_RESOURCE_ALIAS, visible) def test_user_visible_user_confs(self): """Minimal test for user_visible_user_confs""" # Owners should see themselves in auto map # NOTE: use provision users to skip fixtures here - self._provision_test_users(self, self.TEST_OWNER_DN, - self.TEST_MEMBER_DN) + self._provision_test_users( + self, self.TEST_OWNER_DN, self.TEST_MEMBER_DN + ) force_update_user_map(self.configuration) visible = user_visible_user_confs( - self.configuration, self.TEST_OWNER_DN) + self.configuration, self.TEST_OWNER_DN + ) self.assertIn(self.TEST_OWNER_UUID, visible) def test_get_re_provider_map(self): """Test RE provider map includes test resource""" - test_re = 'Python' - res_config = {'RUNTIMEENVIRONMENT': [(test_re, '/python/path')]} - self._create_resource(self.TEST_RESOURCE_ID, [ - self.TEST_OWNER_DN], res_config) + 
self._create_resource( + self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN], res_config + ) # Update maps to include new resource force_update_resource_map(self.configuration) @@ -870,10 +985,11 @@ def test_get_re_provider_map(self): def test_resources_using_re(self): """Test finding resources with specific runtime environment""" - test_re = 'Bash' - res_config = {'RUNTIMEENVIRONMENT': [(test_re, '/bash/path')]} - self._create_resource(self.TEST_RESOURCE_ID, [ - self.TEST_OWNER_DN], res_config) + test_re = "Bash" + res_config = {"RUNTIMEENVIRONMENT": [(test_re, "/bash/path")]} + self._create_resource( + self.TEST_RESOURCE_ID, [self.TEST_OWNER_DN], res_config + ) # Refresh resource map force_update_resource_map(self.configuration) @@ -909,28 +1025,32 @@ def test_unmap_resource(self): def test_access_nonexistent_vgrid(self): """Ensure checks fail cleanly for non-existent vgrid""" - allowed = check_vgrid_access(self.configuration, self.TEST_MEMBER_DN, - 'no-such-vgrid') + allowed = check_vgrid_access( + self.configuration, self.TEST_MEMBER_DN, "no-such-vgrid" + ) self.assertFalse(allowed) # Should not appear in allowed vgrids allowed_vgrids = user_vgrid_access( - self.configuration, self.TEST_MEMBER_DN) - self.assertNotIn('no-such-vgrid', allowed_vgrids) + self.configuration, self.TEST_MEMBER_DN + ) + self.assertNotIn("no-such-vgrid", allowed_vgrids) def test_empty_member_access(self): """Verify members-only vgrid rejects outsiders""" - self._create_vgrid(self.TEST_VGRID_NAME, owners=[], - members=[self.TEST_MEMBER_DN]) + self._create_vgrid( + self.TEST_VGRID_NAME, owners=[], members=[self.TEST_MEMBER_DN] + ) initial_vgrid_map = force_update_vgrid_map(self.configuration) self._verify_vgrid_map_integrity(initial_vgrid_map) self.assertIn(self.TEST_VGRID_NAME, initial_vgrid_map.get(VGRIDS, {})) # Outsider should be blocked despite no owners - allowed = check_vgrid_access(self.configuration, self.TEST_OUTSIDER_DN, - self.TEST_VGRID_NAME) + allowed = check_vgrid_access( + self.configuration, self.TEST_OUTSIDER_DN, self.TEST_VGRID_NAME + ) self.assertFalse(allowed) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_mig_unittest_testcore.py b/tests/test_mig_unittest_testcore.py index 24a420c2c..c3e43629c 100644 --- a/tests/test_mig_unittest_testcore.py +++ b/tests/test_mig_unittest_testcore.py @@ -31,16 +31,15 @@ import os import sys -from tests.support import MigTestCase, testmain - from mig.unittest.testcore import legacy_main +from tests.support import MigTestCase, testmain class MigUnittestTestcore__legacy_main(MigTestCase): """Legacy tests for corresponding module self-checks""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_existing_main(self): """Run the legacy self-tests directly in module""" @@ -48,9 +47,10 @@ def test_existing_main(self): def raise_on_error_exit(exit_code, identifying_message=None): if exit_code != 0: if identifying_message is None: - identifying_message = 'unknown' + identifying_message = "unknown" raise AssertionError( - 'legacy test failure: %s' % (identifying_message,)) + "legacy test failure: %s" % (identifying_message,) + ) raise_on_error_exit.last_print = None @@ -58,8 +58,12 @@ def record_last_print(value): """Keep track of printed output""" raise_on_error_exit.last_print = value - legacy_main(self.configuration, print=record_last_print, _exit=raise_on_error_exit) + legacy_main( + self.configuration, + print=record_last_print, + _exit=raise_on_error_exit, + ) -if __name__ == '__main__': +if 
__name__ == "__main__": testmain() diff --git a/tests/test_mig_wsgibin.py b/tests/test_mig_wsgibin.py index 1d0f9ecdb..3c487ad60 100644 --- a/tests/test_mig_wsgibin.py +++ b/tests/test_mig_wsgibin.py @@ -38,12 +38,24 @@ # Imports required for the unit test wrapping import mig.shared.returnvalues as returnvalues -from mig.shared.base import allow_script, brief_list, client_dir_id, \ - client_id_dir, get_short_id, invisible_path +from mig.shared.base import ( + allow_script, + brief_list, + client_dir_id, + client_id_dir, + get_short_id, + invisible_path, +) from mig.shared.compat import SimpleNamespace + # Imports required for the unit tests themselves -from tests.support import MIG_BASE, MigTestCase, ensure_dirs_exist, \ - is_path_within, testmain +from tests.support import ( + MIG_BASE, + MigTestCase, + ensure_dirs_exist, + is_path_within, + testmain, +) from tests.support.snapshotsupp import SnapshotAssertMixin from tests.support.wsgisupp import WsgiAssertMixin, prepare_wsgi @@ -60,12 +72,12 @@ def __init__(self): def handle_decl(self, decl): try: - decltag, decltype = decl.split(' ') + decltag, decltype = decl.split(" ") except Exception: decltag = "" decltype = "" - if decltag.upper() == 'DOCTYPE': + if decltag.upper() == "DOCTYPE": self._saw_doctype = True else: decltype = "unknown" @@ -73,11 +85,11 @@ def handle_starttag(self, tag, attrs): - if tag == 'html': + if tag == "html": if self._saw_tags: - tag_html = 'not_first' + tag_html = "not_first" else: - tag_html = 'was_first' + tag_html = "was_first" self._tag_html = tag_html self._saw_tags = True @@ -85,13 +97,13 @@ def assert_basics(self): if not self._saw_doctype: raise AssertionError("missing DOCTYPE") - if self._doctype != 'html': + if self._doctype != "html": raise AssertionError("non-html DOCTYPE") - if self._tag_html == 'none': + if self._tag_html == "none": raise AssertionError("missing <html>") - if self._tag_html != 'was_first': + if self._tag_html != "was_first": raise AssertionError("first tag seen was not <html>") @@ -110,13 +122,13 @@ def handle_data(self, *args, **kwargs): def handle_starttag(self, tag, attrs): DocumentBasicsHtmlParser.handle_starttag(self, tag, attrs) - if tag == 'title': + if tag == "title": self._within_title = True def handle_endtag(self, tag): DocumentBasicsHtmlParser.handle_endtag(self, tag) - if tag == 'title': + if tag == "title": self._within_title = False def title(self, trim_newlines=False): @@ -138,7 +150,7 @@ def _import_forcibly(module_name, relative_module_dir=None): that resides within a non-module directory. 
""" - module_path = os.path.join(MIG_BASE, 'mig') + module_path = os.path.join(MIG_BASE, "mig") if relative_module_dir is not None: module_path = os.path.join(module_path, relative_module_dir) sys.path.append(module_path) @@ -148,7 +160,7 @@ def _import_forcibly(module_name, relative_module_dir=None): # Imports of the code under test (indirect import needed here) -migwsgi = _import_forcibly('migwsgi', relative_module_dir='wsgi-bin') +migwsgi = _import_forcibly("migwsgi", relative_module_dir="wsgi-bin") class FakeBackend: @@ -160,8 +172,8 @@ class FakeBackend: def __init__(self): self.output_objects = [ - {'object_type': 'start'}, - {'object_type': 'title', 'text': 'ERROR'}, + {"object_type": "start"}, + {"object_type": "title", "text": "ERROR"}, ] self.return_value = returnvalues.ERROR @@ -175,6 +187,7 @@ def set_response(self, output_objects, returnvalue): def to_import_module(self): def _import_module(module_path): return self + return _import_module @@ -182,11 +195,11 @@ class MigWsgibin(MigTestCase, SnapshotAssertMixin, WsgiAssertMixin): """WSGI glue test cases""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/') + self.fake_wsgi = prepare_wsgi(self.configuration, "http://localhost/") self.application_args = ( self.fake_wsgi.environ, @@ -208,51 +221,45 @@ def assertHtmlTitle(self, value, title_text=None, trim_newlines=False): def test_top_level_request_returns_status_ok(self): wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) def test_objects_containing_only_title_has_expected_title(self): - output_objects = [ - {'object_type': 'title', 'text': 'TEST'} - ] + output_objects = [{"object_type": "title", "text": "TEST"}] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) - self.assertHtmlTitle(output, title_text='TEST', trim_newlines=True) + self.assertHtmlTitle(output, title_text="TEST", trim_newlines=True) def test_objects_containing_only_title_matches_snapshot(self): - output_objects = [ - {'object_type': 'title', 'text': 'TEST'} - ] + output_objects = [{"object_type": "title", "text": "TEST"}] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) - self.assertSnapshot(output, extension='html') + self.assertSnapshot(output, extension="html") -class MigWsgibin_output_objects(MigTestCase, WsgiAssertMixin, - SnapshotAssertMixin): +class MigWsgibin_output_objects( + MigTestCase, WsgiAssertMixin, SnapshotAssertMixin +): """Unit tests for output_object related part of wsgi functions.""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/') + self.fake_wsgi = prepare_wsgi(self.configuration, "http://localhost/") self.application_args = ( self.fake_wsgi.environ, @@ -273,51 +280,46 @@ def 
test_unknown_object_type_generates_valid_error_page(self): self.logger.forgive_errors() output_objects = [ { - 'object_type': 'nonexistent', # trigger error handling path + "object_type": "nonexistent", # trigger error handling path } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) - output, _ = self.assertWsgiResponse( - wsgi_result, self.fake_wsgi, 200) + output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) self.assertIsValidHtmlDocument(output) def test_objects_with_type_text(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) self.assertSnapshotOfHtmlContent(output) -class MigWsgibin_input_object(MigTestCase, WsgiAssertMixin, - SnapshotAssertMixin): +class MigWsgibin_input_object( + MigTestCase, WsgiAssertMixin, SnapshotAssertMixin +): """Unit tests for input_object related part of wsgi functions.""" - DUMMY_BYTES = 'dummyæøå-ßßß-value'.encode('utf-8') + DUMMY_BYTES = "dummyæøå-ßßß-value".encode("utf-8") def _provide_configuration(self): - return 'testconfig' + return "testconfig" def before_each(self): self.fake_backend = FakeBackend() @@ -330,8 +332,9 @@ def _prepare_test(self, form_overrides=None, custom_env=None): # Set up a wsgi input with non-ascii bytes and open it in binary mode # If form_overrides is passed a list of tuples like [('key' 'val')] it # produces a fake_wsgi input on the form: b'key=val' - self.fake_wsgi = prepare_wsgi(self.configuration, 'http://localhost/', - form=form_overrides) + self.fake_wsgi = prepare_wsgi( + self.configuration, "http://localhost/", form=form_overrides + ) # override the default environ fields from wsgisupp if custom_env: self.fake_wsgi.environ.update(custom_env) @@ -348,7 +351,7 @@ def _prepare_test(self, form_overrides=None, custom_env=None): # NOTE: enabled with underlying wsgi use of Fieldstorage fixed def test_put_text_plain_with_binary_input_succeeds(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = { "REQUEST_METHOD": "PUT", "CONTENT_TYPE": "text/plain", @@ -358,20 +361,16 @@ def test_put_text_plain_with_binary_input_succeeds(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must succeed with HTTP 200 when it parses input @@ -379,7 +378,7 @@ def test_put_text_plain_with_binary_input_succeeds(self): @unittest.skip("disabled with underlying wsgi use of Fieldstorage fixed") def test_put_text_plain_with_binary_input_fails(self): - test_form = [('_csrf', 
self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = { "REQUEST_METHOD": "PUT", "CONTENT_TYPE": "text/plain", @@ -389,52 +388,44 @@ def test_put_text_plain_with_binary_input_fails(self): output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) # TODO: can we add assertLogs to check error log explicitly? wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must fail with HTTP 500 from failing to parse input output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 500) def test_post_url_encoded_with_binary_input_succeeds(self): - test_form = [('_csrf', self.DUMMY_BYTES)] + test_form = [("_csrf", self.DUMMY_BYTES)] test_env = None self._prepare_test(test_form, test_env) output_objects = [ # workaround invalid HTML being generated with no title object + {"object_type": "title", "text": "TEST"}, { - 'object_type': 'title', - 'text': 'TEST' + "object_type": "text", + "text": "some text", }, - { - 'object_type': 'text', - 'text': 'some text', - } ] self.fake_backend.set_response(output_objects, returnvalues.OK) wsgi_result = migwsgi.application( - *self.application_args, - **self.application_kwargs + *self.application_args, **self.application_kwargs ) # Must succeed with HTTP 200 when it parses input output, _ = self.assertWsgiResponse(wsgi_result, self.fake_wsgi, 200) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_support.py b/tests/test_support.py index 56b74f594..b68c4249c 100644 --- a/tests/test_support.py +++ b/tests/test_support.py @@ -28,15 +28,21 @@ """Unit tests for the tests module pointed to in the filename""" from __future__ import print_function + import os import sys import unittest -from tests.support import MigTestCase, PY2, testmain, temppath, \ - AssertOver, FakeConfiguration - from mig.shared.conf import get_configuration_object from mig.shared.configuration import Configuration +from tests.support import ( + PY2, + AssertOver, + FakeConfiguration, + MigTestCase, + temppath, + testmain, +) class InstrumentedAssertOver(AssertOver): @@ -62,6 +68,7 @@ def to_check_callable(self): def _wrapped_check_callable(): self._check_callable_called = True _check_callable() + self._check_callable = _wrapped_check_callable return _wrapped_check_callable @@ -71,8 +78,8 @@ class SupportTestCase(MigTestCase): def _class_attribute(self, name, **kwargs): cls = type(self) - if 'value' in kwargs: - setattr(cls, name, kwargs['value']) + if "value" in kwargs: + setattr(cls, name, kwargs["value"]) else: return getattr(cls, name, None) @@ -80,15 +87,17 @@ def test_requires_requesting_a_configuration(self): with self.assertRaises(AssertionError) as raised: self.configuration theexception = raised.exception - self.assertEqual(str(theexception), - "configuration access but testcase did not request it") + self.assertEqual( + str(theexception), + "configuration access but testcase did not request it", + ) @unittest.skipIf(PY2, "Python 3 only") def test_unclosed_files_are_recorded(self): tmp_path = temppath("support-unclosed", self) def open_without_close(): - with open(tmp_path, 'w'): + with open(tmp_path, "w"): pass open(tmp_path) return @@ -112,11 +121,13 @@ def 
assert_is_int(value): assert isinstance(value, int) attempt_wrapper = self.assert_over( - values=(1, 2, 3), _AssertOver=InstrumentedAssertOver) + values=(1, 2, 3), _AssertOver=InstrumentedAssertOver + ) # record the wrapper on the test case so the subsequent test can assert against it - self._class_attribute('surviving_attempt_wrapper', - value=attempt_wrapper) + self._class_attribute( + "surviving_attempt_wrapper", value=attempt_wrapper + ) with attempt_wrapper as attempt: attempt(assert_is_int) @@ -124,14 +135,15 @@ def assert_is_int(value): self.assertTrue(attempt_wrapper.has_check_callable()) # cleanup was recorded - self.assertIn(attempt_wrapper.get_check_callable(), - self._cleanup_checks) + self.assertIn( + attempt_wrapper.get_check_callable(), self._cleanup_checks + ) def test_when_asserting_over_multiple_values_after(self): # test name is purposefully after ..._recorded in sort order # such that we can check the check function was called correctly - attempt_wrapper = self._class_attribute('surviving_attempt_wrapper') + attempt_wrapper = self._class_attribute("surviving_attempt_wrapper") self.assertTrue(attempt_wrapper.was_check_callable_called()) @@ -139,7 +151,7 @@ class SupportTestCase_using_fakeconfig(MigTestCase): """Coverage of a MiG Testcase that requests a fakeconfig""" def _provide_configuration(self): - return 'fakeconfig' + return "fakeconfig" def test_provides_a_fake_configuration(self): configuration = self.configuration @@ -157,10 +169,10 @@ class SupportTestCase_using_testconfig(MigTestCase): """Coverage of a MiG Testcase that requests a testconfig""" def _provide_configuration(self): - return 'testconfig' + return "testconfig" def test_provides_the_test_configuration(self): - expected_last_dir = 'testconfs-py2' if PY2 else 'testconfs-py3' + expected_last_dir = "testconfs-py2" if PY2 else "testconfs-py3" configuration = self.configuration @@ -173,5 +185,5 @@ def test_provides_the_test_configuration(self): self.assertTrue(config_file_last_dir, expected_last_dir) -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_tests_support_assertover.py b/tests/test_tests_support_assertover.py index 702a80e94..ae9be10b2 100644 --- a/tests/test_tests_support_assertover.py +++ b/tests/test_tests_support_assertover.py @@ -35,7 +35,7 @@ def assert_a_thing(value): """A simple assert helper to test with""" - assert value.endswith(' thing'), "must end with a thing" + assert value.endswith(" thing"), "must end with a thing" class TestsSupportAssertOver(unittest.TestCase): @@ -44,7 +44,9 @@ class TestsSupportAssertOver(unittest.TestCase): def test_none_failing(self): saw_raise = False try: - with AssertOver(values=('some thing', 'other thing')) as value_block: + with AssertOver( + values=("some thing", "other thing") + ) as value_block: value_block(lambda _: assert_a_thing(_)) except Exception as exc: saw_raise = True @@ -52,13 +54,18 @@ def test_three_total_two_failing(self): with self.assertRaises(AssertionError) as raised: - with AssertOver(values=('some thing', 'other stuff', 'foobar')) as value_block: + with AssertOver( + values=("some thing", "other stuff", "foobar") + ) as value_block: value_block(lambda _: assert_a_thing(_)) theexception = raised.exception - self.assertEqual(str(theexception), """assertions raised for the following values: + self.assertEqual( + str(theexception), + """assertions raised for the following values: - <'other stuff'> : must end with a thing -- <'foobar'> : must end with a thing""") +- 
<'foobar'> : must end with a thing""", + ) def test_no_cases(self): with self.assertRaises(AssertionError) as raised: @@ -69,5 +76,5 @@ def test_no_cases(self): self.assertIsInstance(theexception, NoCasesError) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_tests_support_configsupp.py b/tests/test_tests_support_configsupp.py index 0e85cd2c1..8d8d23e2d 100644 --- a/tests/test_tests_support_configsupp.py +++ b/tests/test_tests_support_configsupp.py @@ -27,11 +27,10 @@ """Unit tests for the tests module pointed to in the filename""" +from mig.shared.configuration import Configuration from tests.support import MigTestCase, testmain from tests.support.configsupp import FakeConfiguration -from mig.shared.configuration import Configuration - class TestsSupportConfigsupp_FakeConfiguration(MigTestCase): """Check some basic behaviours of FakeConfiguration instances.""" @@ -43,13 +42,13 @@ def test_consistent_parameters(self): self.maxDiff = None self.assertEqual( Configuration.to_dict(default_configuration), - Configuration.to_dict(fake_configuration) + Configuration.to_dict(fake_configuration), ) def test_only_configuration_keys(self): with self.assertRaises(AssertionError): - FakeConfiguration(bar='1') + FakeConfiguration(bar="1") -if __name__ == '__main__': +if __name__ == "__main__": testmain() diff --git a/tests/test_tests_support_wsgisupp.py b/tests/test_tests_support_wsgisupp.py index 4adcb975d..03e5346e4 100644 --- a/tests/test_tests_support_wsgisupp.py +++ b/tests/test_tests_support_wsgisupp.py @@ -28,15 +28,15 @@ """Unit tests for the tests module pointed to in the filename""" import unittest -from mig.shared.compat import SimpleNamespace +from mig.shared.compat import SimpleNamespace from tests.support import AssertOver from tests.support.wsgisupp import prepare_wsgi def assert_a_thing(value): """A simple assert helper to test with""" - assert value.endswith(' thing'), "must end with a thing" + assert value.endswith(" thing"), "must end with a thing" class TestsSupportWsgisupp_prepare_wsgi(unittest.TestCase): @@ -44,56 +44,57 @@ class TestsSupportWsgisupp_prepare_wsgi(unittest.TestCase): def test_prepare_GET(self): configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, 'http://testhost/some/path') + environ, _ = prepare_wsgi(configuration, "http://testhost/some/path") - self.assertEqual(environ['MIG_CONF'], - '/path/to/the/confs/MiGserver.conf') - self.assertEqual(environ['HTTP_HOST'], 'testhost') - self.assertEqual(environ['PATH_INFO'], '/some/path') - self.assertEqual(environ['REQUEST_METHOD'], 'GET') + self.assertEqual( + environ["MIG_CONF"], "/path/to/the/confs/MiGserver.conf" + ) + self.assertEqual(environ["HTTP_HOST"], "testhost") + self.assertEqual(environ["PATH_INFO"], "/some/path") + self.assertEqual(environ["REQUEST_METHOD"], "GET") def test_prepare_GET_with_query(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, test_url, query={ - 'foo': 'true', - 'bar': 1 - }) + environ, _ = prepare_wsgi( + configuration, test_url, query={"foo": "true", "bar": 1} + ) - self.assertEqual(environ['QUERY_STRING'], 'foo=true&bar=1') + self.assertEqual(environ["QUERY_STRING"], "foo=true&bar=1") def 
test_prepare_POST(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) - environ, _ = prepare_wsgi(configuration, test_url, method='POST') + environ, _ = prepare_wsgi(configuration, test_url, method="POST") - self.assertEqual(environ['REQUEST_METHOD'], 'POST') + self.assertEqual(environ["REQUEST_METHOD"], "POST") def test_prepare_POST_with_headers(self): - test_url = 'http://testhost/some/path' + test_url = "http://testhost/some/path" configuration = SimpleNamespace( - config_file='/path/to/the/confs/MiGserver.conf' + config_file="/path/to/the/confs/MiGserver.conf" ) headers = { - 'Authorization': 'Basic XXXX', - 'Content-Length': 0, + "Authorization": "Basic XXXX", + "Content-Length": 0, } environ, _ = prepare_wsgi( - configuration, test_url, method='POST', headers=headers) + configuration, test_url, method="POST", headers=headers + ) - self.assertEqual(environ['CONTENT_LENGTH'], 0) - self.assertEqual(environ['HTTP_AUTHORIZATION'], 'Basic XXXX') + self.assertEqual(environ["CONTENT_LENGTH"], 0) + self.assertEqual(environ["HTTP_AUTHORIZATION"], "Basic XXXX") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()