diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index b8d9ca6866..1a70118780 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -722,8 +722,11 @@ def _create_or_update_code_dir( # Validate dependency path for security _validate_dependency_path(dependency) lib_dir = os.path.join(code_dir, "lib") + dep_basename = os.path.basename(dependency) + if not dep_basename or dep_basename in (".", ".."): + raise ValueError(f"Invalid dependency path: {dependency}") if os.path.isdir(dependency): - shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency))) + shutil.copytree(dependency, os.path.join(lib_dir, dep_basename)) else: if not os.path.exists(lib_dir): os.mkdir(lib_dir) @@ -739,6 +742,14 @@ def _extract_model(model_uri, sagemaker_session, tmp): download_file_from_url(model_uri, local_model_path, sagemaker_session) else: local_model_path = model_uri.replace("file://", "") + # Validate local model path does not access sensitive system paths + abs_model_path = abspath(realpath(local_model_path)) + for sensitive_path in _SENSITIVE_SYSTEM_PATHS: + if abs_model_path != "/" and abs_model_path.startswith(sensitive_path): + raise ValueError( + f"model_uri cannot access sensitive system paths. " + f"Got: {model_uri} (resolved to {abs_model_path})" + ) with tarfile.open(name=local_model_path, mode="r:gz") as t: custom_extractall_tarfile(t, tmp_model_dir) return tmp_model_dir @@ -1418,984 +1429,6 @@ def resolve_nested_dict_value_from_config( config_value = get_sagemaker_config_value(sagemaker_session, config_path) if config_value is None and default_value is None: - # if there is nothing to set, return early. And there is no need to traverse through - # the dictionary or add nested dicts to it - return dictionary - - try: - current_nested_value = get_nested_value(dictionary, nested_keys) - except ValueError as e: - logging.error("Failed to check dictionary for applying sagemaker config: %s", e) - return dictionary - - if current_nested_value is None: - # only set value if not already set - if config_value is not None: - dictionary = set_nested_value(dictionary, nested_keys, config_value) - elif default_value is not None: - dictionary = set_nested_value(dictionary, nested_keys, default_value) - - from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution - - _log_sagemaker_config_single_substitution(current_nested_value, config_value, config_path) - - return dictionary - - -def update_list_of_dicts_with_values_from_config( - input_list, - config_key_path, - required_key_paths: List[str] = None, - union_key_paths: List[List[str]] = None, - sagemaker_session=None, -): - """Updates a list of dictionaries with missing values that are present in Config. - - In some cases, config file might introduce new parameters which requires certain other - parameters to be provided as part of the input list. Without those parameters, the underlying - service will throw an exception. This method provides the capability to specify required key - paths. - - In some other cases, config file might introduce new parameters but the service API requires - either an existing parameter or the new parameter that was supplied by config but not both - - Args: - input_list: The input list that was provided as a method parameter. - config_key_path: The Key Path in the Config file that corresponds to the input_list - parameter. - required_key_paths (List[str]): List of required key paths that should be verified in the - merged output. If a required key path is missing, we will not perform the merge for that - item. - union_key_paths (List[List[str]]): List of List of Key paths for which we need to verify - whether exactly zero/one of the parameters exist. - For example: If the resultant dictionary can have either 'X1' or 'X2' as parameter or - neither but not both, then pass [['X1', 'X2']] - sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object, used for - SageMaker interactions (default: None). - - Returns: - No output. In place merge happens. - """ - if not input_list: - return - inputs_copy = copy.deepcopy(input_list) - inputs_from_config = get_sagemaker_config_value(sagemaker_session, config_key_path) or [] - unmodified_inputs_from_config = copy.deepcopy(inputs_from_config) - - for i in range(min(len(input_list), len(inputs_from_config))): - dict_from_inputs = input_list[i] - dict_from_config = inputs_from_config[i] - merge_dicts(dict_from_config, dict_from_inputs) - # Check if required key paths are present in merged dict (dict_from_config) - required_key_path_check_passed = _validate_required_paths_in_a_dict( - dict_from_config, required_key_paths - ) - if not required_key_path_check_passed: - # Don't do the merge, config is introducing a new parameter which needs a - # corresponding required parameter. - continue - union_key_path_check_passed = _validate_union_key_paths_in_a_dict( - dict_from_config, union_key_paths - ) - if not union_key_path_check_passed: - # Don't do the merge, Union parameters are not obeyed. - continue - input_list[i] = dict_from_config - - from sagemaker.core.config.config_utils import _log_sagemaker_config_merge - - _log_sagemaker_config_merge( - source_value=inputs_copy, - config_value=unmodified_inputs_from_config, - merged_source_and_config_value=input_list, - config_key_path=config_key_path, - ) - - -def _validate_required_paths_in_a_dict(source_dict, required_key_paths: List[str] = None) -> bool: - """Placeholder docstring""" - if not required_key_paths: - return True - for required_key_path in required_key_paths: - if get_config_value(required_key_path, source_dict) is None: - return False - return True - - -def _validate_union_key_paths_in_a_dict( - source_dict, union_key_paths: List[List[str]] = None -) -> bool: - """Placeholder docstring""" - if not union_key_paths: - return True - for union_key_path in union_key_paths: - union_parameter_present = False - for key_path in union_key_path: - if get_config_value(key_path, source_dict): - if union_parameter_present: - return False - union_parameter_present = True - return True - - -def update_nested_dictionary_with_values_from_config( - source_dict, config_key_path, sagemaker_session=None -) -> dict: - """Updates a nested dictionary with missing values that are present in Config. - - Args: - source_dict: The input nested dictionary that was provided as method parameter. - config_key_path: The Key Path in the Config file which corresponds to this - source_dict parameter. - sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object, used for - SageMaker interactions (default: None). - - Returns: - dict: The merged nested dictionary that is updated with missing values that are present - in the Config file. - """ - inferred_config_dict = get_sagemaker_config_value(sagemaker_session, config_key_path) or {} - original_config_dict_value = copy.deepcopy(inferred_config_dict) - merge_dicts(inferred_config_dict, source_dict or {}) - - if original_config_dict_value == {}: - # The config value is empty. That means either - # (1) inferred_config_dict equals source_dict, or - # (2) if source_dict was None, inferred_config_dict equals {} - # We should return whatever source_dict was to be safe. Because if for example, - # a VpcConfig is set to {} instead of None, some boto calls will fail due to - # ParamValidationError (because a VpcConfig was specified but required parameters for - # the VpcConfig were missing.) - - # Don't need to print because no config value was used or defined - return source_dict - - from sagemaker.core.config.config_utils import _log_sagemaker_config_merge - - _log_sagemaker_config_merge( - source_value=source_dict, - config_value=original_config_dict_value, - merged_source_and_config_value=inferred_config_dict, - config_key_path=config_key_path, - ) - - return inferred_config_dict - - -def stringify_object(obj: Any) -> str: - """Returns string representation of object, returning only non-None fields.""" - non_none_atts = {key: value for key, value in obj.__dict__.items() if value is not None} - return f"{type(obj).__name__}: {str(non_none_atts)}" - - -def volume_size_supported(instance_type: str) -> bool: - """Returns True if SageMaker allows volume_size to be used for the instance type. - - Raises: - ValueError: If the instance type is improperly formatted. - """ - - try: - - # local mode does not support volume size - # instance type given as pipeline parameter does not support volume size - # do not change the if statement order below. - if is_pipeline_variable(instance_type) or instance_type.startswith("local"): - return False - - parts: List[str] = instance_type.split(".") - - if len(parts) == 3 and parts[0] == "ml": - parts = parts[1:] - - if len(parts) != 2: - raise ValueError(f"Failed to parse instance type '{instance_type}'") - - # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) - # + g5 or g6 or p5 does not support attaching an EBS volume. - family = parts[0] - - unsupported_families = ["g5", "g6", "p5", "trn1"] - - return "d" not in family and not any( - family.startswith(prefix) for prefix in unsupported_families - ) - except Exception as e: - raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") - - -def instance_supports_kms(instance_type: str) -> bool: - """Returns True if SageMaker allows KMS keys to be attached to the instance. - - Raises: - ValueError: If the instance type is improperly formatted. - """ - return volume_size_supported(instance_type) - - -def get_instance_type_family(instance_type: str) -> str: - """Return the family of the instance type. - - Regex matches either "ml.." or "ml_. If input is None - or there is no match, return an empty string. - """ - instance_type_family = "" - if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type) - if match is not None: - instance_type_family = match[1] - return instance_type_family - - -def create_paginator_config(max_items: int = None, page_size: int = None) -> Dict[str, int]: - """Placeholder docstring""" - return { - "MaxItems": max_items if max_items else MAX_ITEMS, - "PageSize": page_size if page_size else PAGE_SIZE, - } - - -def format_tags(tags: Tags) -> List[TagsDict]: - """Process tags to turn them into the expected format for Sagemaker.""" - if isinstance(tags, dict): - return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()] - - return tags - - -def _get_resolved_path(path): - """Return the normalized absolute path of a given path. - - abspath - returns the absolute path without resolving symlinks - realpath - resolves the symlinks and gets the actual path - normpath - normalizes paths (e.g. remove redudant separators) - and handles platform-specific differences - """ - return normpath(realpath(abspath(path))) - - -def _is_bad_path(path, base): - """Checks if the joined path (base directory + file path) is rooted under the base directory - - Ensuring that the file does not attempt to access paths - outside the expected directory structure. - - Args: - path (str): The file path. - base (str): The base directory. - - Returns: - bool: True if the path is not rooted under the base directory, False otherwise. - """ - # joinpath will ignore base if path is absolute - return not _get_resolved_path(joinpath(base, path)).startswith(base) - - -def _is_bad_link(info, base): - """Checks if the link is rooted under the base directory. - - Ensuring that the link does not attempt to access paths outside the expected directory structure - - Args: - info (tarfile.TarInfo): The tar file info. - base (str): The base directory. - - Returns: - bool: True if the link is not rooted under the base directory, False otherwise. - """ - # Links are interpreted relative to the directory containing the link - tip = _get_resolved_path(joinpath(base, dirname(info.name))) - return _is_bad_path(info.linkname, base=tip) - - -def _get_safe_members(members): - """A generator that yields members that are safe to extract. - - It filters out bad paths and bad links. - - Args: - members (list): A list of members to check. - - Yields: - tarfile.TarInfo: The tar file info. - """ - base = _get_resolved_path("") - - for file_info in members: - if _is_bad_path(file_info.name, base): - logger.error("%s is blocked (illegal path)", file_info.name) - elif file_info.issym() and _is_bad_link(file_info, base): - logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname) - elif file_info.islnk() and _is_bad_link(file_info, base): - logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname) - else: - yield file_info - - -def _validate_extracted_paths(extract_path): - """Validate that extracted paths remain within the expected directory. - - Performs post-extraction validation to ensure all extracted files and directories - are within the intended extraction path. - - Args: - extract_path (str): The path where files were extracted. - - Raises: - ValueError: If any extracted file is outside the expected extraction path. - """ - base = _get_resolved_path(extract_path) - - for root, dirs, files in os.walk(extract_path): - # Check directories - for dir_name in dirs: - dir_path = os.path.join(root, dir_name) - resolved = _get_resolved_path(dir_path) - if not resolved.startswith(base): - logger.error("Extracted directory escaped extraction path: %s", dir_path) - raise ValueError(f"Extracted path outside expected directory: {dir_path}") - - # Check files - for file_name in files: - file_path = os.path.join(root, file_name) - resolved = _get_resolved_path(file_path) - if not resolved.startswith(base): - logger.error("Extracted file escaped extraction path: %s", file_path) - raise ValueError(f"Extracted path outside expected directory: {file_path}") - - -def custom_extractall_tarfile(tar, extract_path): - """Extract a tarfile, optionally using data_filter if available. - - # TODO: The function and it's usages can be deprecated once SageMaker Python SDK - is upgraded to use Python 3.12+ - - If the tarfile has a data_filter attribute, it will be used to extract the contents of the file. - Otherwise, the _get_safe_members function will be used to filter bad paths and bad links. - - Args: - tar (tarfile.TarFile): The opened tarfile object. - extract_path (str): The path to extract the contents of the tarfile. - - Returns: - None - """ - if hasattr(tarfile, "data_filter"): - tar.extractall(path=extract_path, filter="data") - else: - tar.extractall(path=extract_path, members=_get_safe_members(tar)) - # Re-validate extracted paths to catch symlink race conditions - _validate_extracted_paths(extract_path) - - -def can_model_package_source_uri_autopopulate(source_uri: str): - """Checks if the source_uri can lead to auto-population of information in the Model registry. - - Args: - source_uri (str): The source uri. - - Returns: - bool: True if the source_uri can lead to auto-population, False otherwise. - """ - return bool( - re.match(MODEL_PACKAGE_ARN_PATTERN, source_uri) or re.match(MODEL_ARN_PATTERN, source_uri) - ) - - -def flatten_dict( - d: Dict[str, Any], - max_flatten_depth=None, -) -> Dict[str, Any]: - """Flatten a dictionary object. - - d (Dict[str, Any]): - The dict that will be flattened. - max_flatten_depth (Optional[int]): - Maximum depth to merge. - """ - - def tuple_reducer(k1, k2): - if k1 is None: - return (k2,) - return k1 + (k2,) - - # check max_flatten_depth - if max_flatten_depth is not None and max_flatten_depth < 1: - raise ValueError("max_flatten_depth should not be less than 1.") - - reducer = tuple_reducer - - flat_dict = {} - - def _flatten(_d, depth, parent=None): - key_value_iterable = viewitems(_d) - has_item = False - for key, value in key_value_iterable: - has_item = True - flat_key = reducer(parent, key) - if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth): - has_child = _flatten(value, depth=depth + 1, parent=flat_key) - if has_child: - continue - - if flat_key in flat_dict: - raise ValueError("duplicated key '{}'".format(flat_key)) - flat_dict[flat_key] = value - - return has_item - - _flatten(d, depth=1) - return flat_dict - - -def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None: - """Set a value to a sequence of nested keys.""" - - key = keys[0] - - if len(keys) == 1: - d[key] = value - return - - d = d.setdefault(key, {}) - nested_set_dict(d, keys[1:], value) - - -def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: - """Unflatten dict-like object. - - d (Dict[str, Any]) : - The dict that will be unflattened. - """ - - unflattened_dict = {} - for flat_key, value in viewitems(d): - key_tuple = flat_key - nested_set_dict(unflattened_dict, key_tuple, value) - - return unflattened_dict - - -def deep_override_dict( - dict1: Dict[str, Any], dict2: Dict[str, Any], skip_keys: Optional[List[str]] = None -) -> Dict[str, Any]: - """Overrides any overlapping contents of dict1 with the contents of dict2.""" - if skip_keys is None: - skip_keys = [] - - flattened_dict1 = flatten_dict(dict1) - flattened_dict1 = {key: value for key, value in flattened_dict1.items() if value is not None} - flattened_dict2 = flatten_dict( - {key: value for key, value in dict2.items() if key not in skip_keys} - ) - flattened_dict1.update(flattened_dict2) - return unflatten_dict(flattened_dict1) if flattened_dict1 else {} - - -def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: - """Resolve Routing Config - - Args: - routing_config (Optional[Dict[str, Any]]): The routing config. - - Returns: - Optional[Dict[str, Any]]: The resolved routing config. - - Raises: - ValueError: If the RoutingStrategy is invalid. - """ - - if routing_config: - routing_strategy = routing_config.get("RoutingStrategy", None) - if routing_strategy: - if isinstance(routing_strategy, RoutingStrategy): - return {"RoutingStrategy": routing_strategy.name} - if isinstance(routing_strategy, str) and ( - routing_strategy.upper() == RoutingStrategy.RANDOM.name - or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name - ): - return {"RoutingStrategy": routing_strategy.upper()} - raise ValueError( - "RoutingStrategy must be either RoutingStrategy.RANDOM " - "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS" - ) - return None - - -@lru_cache -def get_instance_rate_per_hour( - instance_type: str, - region: str, -) -> Optional[Dict[str, str]]: - """Gets instance rate per hour for the given instance type. - - Args: - instance_type (str): The instance type. - region (str): The region. - Returns: - Optional[Dict[str, str]]: Instance rate per hour. - Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}. - - Raises: - Exception: An exception is raised if - the IAM role is not authorized to perform pricing:GetProducts. - or unexpected event happened. - """ - region_name = "us-east-1" - if region.startswith("eu") or region.startswith("af"): - region_name = "eu-central-1" - elif region.startswith("ap") or region.startswith("cn"): - region_name = "ap-south-1" - - pricing_client: boto3.client = boto3.client("pricing", region_name=region_name) - res = pricing_client.get_products( - ServiceCode="AmazonSageMaker", - Filters=[ - {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, - {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, - {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, - ], - ) - - price_list = res.get("PriceList", []) - if len(price_list) > 0: - price_data = price_list[0] - if isinstance(price_data, str): - price_data = json.loads(price_data) - - instance_rate_per_hour = extract_instance_rate_per_hour(price_data) - if instance_rate_per_hour is not None: - return instance_rate_per_hour - raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.") - - -def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]: - """Extract instance rate per hour for the given Price JSON data. - - Args: - price_data (Dict[str, Any]): The Price JSON data. - Returns: - Optional[Dict[str, str], None]: Instance rate per hour. - """ - - if price_data is not None: - price_dimensions = price_data.get("terms", {}).get("OnDemand", {}).values() - for dimension in price_dimensions: - for price in dimension.get("priceDimensions", {}).values(): - for currency in price.get("pricePerUnit", {}).keys(): - value = price.get("pricePerUnit", {}).get(currency) - if value is not None: - value = str(round(float(value), 3)) - return { - "unit": f"{currency}/Hr", - "value": value, - "name": "On-demand Instance Rate", - } - return None - - -def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]: - """Iteratively updates a dictionary to convert all keys from snake_case to PascalCase. - - Args: - data (dict): The dictionary to be updated. - - Returns: - dict: The updated dictionary with keys in PascalCase. - """ - result = {} - - def convert_key(key): - """Converts a snake_case key to PascalCase.""" - return "".join(part.capitalize() for part in key.split("_")) - - def convert_value(value): - """Recursively processes the value of a key-value pair.""" - if isinstance(value, dict): - return camel_case_to_pascal_case(value) - if isinstance(value, list): - return [convert_value(item) for item in value] - - return value - - for key, value in data.items(): - result[convert_key(key)] = convert_value(value) - - return result - - -def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool: - """Returns True if ``tag`` already exists. - - Args: - tag (TagsDict): The tag dictionary. - curr_tags (Optional[Tags]): The current tags. - - Returns: - bool: True if the tag exists. - """ - if curr_tags is None: - return False - - for curr_tag in curr_tags: - if tag["Key"] == curr_tag["Key"]: - return True - - return False - - -def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]: - """Validates new tags against existing tags. - - Args: - new_tags (Optional[Tags]): The new tags. - curr_tags (Optional[Tags]): The current tags. - - Returns: - Optional[Tags]: The updated tags. - """ - if curr_tags is None: - return new_tags - - if curr_tags and isinstance(curr_tags, dict): - curr_tags = [curr_tags] - - if isinstance(new_tags, dict): - if not tag_exists(new_tags, curr_tags): - curr_tags.append(new_tags) - elif isinstance(new_tags, list): - for new_tag in new_tags: - if not tag_exists(new_tag, curr_tags): - curr_tags.append(new_tag) - - return curr_tags - - -def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]: - """Remove a tag with the given key from the list of tags. - - Args: - key (str): The key of the tag to remove. - tags (Optional[Tags]): The current list of tags. - - Returns: - Optional[Tags]: The updated list of tags with the tag removed. - """ - if tags is None: - return tags - if isinstance(tags, dict): - tags = [tags] - - updated_tags = [] - for tag in tags: - if tag["Key"] != key: - updated_tags.append(tag) - - if not updated_tags: - return None - if len(updated_tags) == 1: - return updated_tags[0] - return updated_tags - - -def get_domain_for_region(region: str) -> str: - """Returns the domain for the given region. - - Args: - region (str): AWS region name. - """ - return ALTERNATE_DOMAINS.get(region, "amazonaws.com") - - -def camel_to_snake(camel_case_string: str) -> str: - """Converts camelCase to snake_case_string using a regex. - - This regex cannot handle whitespace ("camelString TwoWords") - """ - return re.sub(r"(? Dict[Any, Any]: - """Recursively walks a json object and applies a given function to the keys. - - stop_keys (Optional[list[str]]): List of field keys that should stop the application function. - Any children of these keys will not have the application function applied to them. - """ - - def _walk_and_apply_json(json_obj, new): - if isinstance(json_obj, dict) and isinstance(new, dict): - for key, value in json_obj.items(): - new_key = apply(key) - if (stop_keys and new_key not in stop_keys) or stop_keys is None: - if isinstance(value, dict): - new[new_key] = {} - _walk_and_apply_json(value, new=new[new_key]) - elif isinstance(value, list): - new[new_key] = [] - for item in value: - _walk_and_apply_json(item, new=new[new_key]) - else: - new[new_key] = value - else: - new[new_key] = value - elif isinstance(json_obj, dict) and isinstance(new, list): - new.append(_walk_and_apply_json(json_obj, new={})) - elif isinstance(json_obj, list) and isinstance(new, dict): - new.update(json_obj) - elif isinstance(json_obj, list) and isinstance(new, list): - new.append(json_obj) - elif isinstance(json_obj, str) and isinstance(new, list): - new.append(json_obj) - return new - - return _walk_and_apply_json(json_obj, new={}) - - -def _wait_until(callable_fn, poll=5): - """Placeholder docstring""" - elapsed_time = 0 - result = None - while result is None: - try: - elapsed_time += poll - time.sleep(poll) - result = callable_fn() - except botocore.exceptions.ClientError as err: - # For initial 5 mins we accept/pass AccessDeniedException. - # The reason is to await tag propagation to avoid false AccessDenied claims for an - # access policy based on resource tags, The caveat here is for true AccessDenied - # cases the routine will fail after 5 mins - if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300: - logger.warning( - "Received AccessDeniedException. This could mean the IAM role does not " - "have the resource permissions, in which case please add resource access " - "and retry. For cases where the role has tag based resource policy, " - "continuing to wait for tag propagation.." - ) - continue - raise err - return result - - -def _flush_log_streams( - stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap -): - """Placeholder docstring""" - if len(stream_names) < instance_count: - # Log streams are created whenever a container starts writing to stdout/err, so this list - # may be dynamic until we have a stream for every instance. - try: - streams = client.describe_log_streams( - logGroupName=log_group, - logStreamNamePrefix=job_name + "/", - orderBy="LogStreamName", - limit=min(instance_count, 50), - ) - stream_names = [s["logStreamName"] for s in streams["logStreams"]] - - while "nextToken" in streams: - streams = client.describe_log_streams( - logGroupName=log_group, - logStreamNamePrefix=job_name + "/", - orderBy="LogStreamName", - limit=50, - ) - - stream_names.extend([s["logStreamName"] for s in streams["logStreams"]]) - - positions.update( - [ - (s, sagemaker.core.logs.Position(timestamp=0, skip=0)) - for s in stream_names - if s not in positions - ] - ) - except ClientError as e: - # On the very first training job run on an account, there's no log group until - # the container starts logging, so ignore any errors thrown about that - err = e.response.get("Error", {}) - if err.get("Code", None) != "ResourceNotFoundException": - raise - - if len(stream_names) > 0: - if dot: - print("") - dot = False - for idx, event in sagemaker.core.logs.multi_stream_iter( - client, log_group, stream_names, positions - ): - color_wrap(idx, event["message"]) - ts, count = positions[stream_names[idx]] - if event["timestamp"] == ts: - positions[stream_names[idx]] = sagemaker.core.logs.Position( - timestamp=ts, skip=count + 1 - ) - else: - positions[stream_names[idx]] = sagemaker.core.logs.Position( - timestamp=event["timestamp"], skip=1 - ) - else: - dot = True - print(".", end="") - sys.stdout.flush() - - -class LogState(object): - """Placeholder docstring""" - - STARTING = 1 - WAIT_IN_PROGRESS = 2 - TAILING = 3 - JOB_COMPLETE = 4 - COMPLETE = 5 - - -_STATUS_CODE_TABLE = { - "COMPLETED": "Completed", - "INPROGRESS": "InProgress", - "IN_PROGRESS": "InProgress", - "FAILED": "Failed", - "STOPPED": "Stopped", - "STOPPING": "Stopping", - "STARTING": "Starting", - "PENDING": "Pending", -} - - -def _get_initial_job_state(description, status_key, wait): - """Placeholder docstring""" - status = description[status_key] - job_already_completed = status in ("Completed", "Failed", "Stopped") - return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE - - -def _logs_init(boto_session, description, job): - """Placeholder docstring""" - if job == "Training": - if "InstanceGroups" in description["ResourceConfig"]: - instance_count = 0 - for instanceGroup in description["ResourceConfig"]["InstanceGroups"]: - instance_count += instanceGroup["InstanceCount"] - else: - instance_count = description["ResourceConfig"]["InstanceCount"] - elif job == "Transform": - instance_count = description["TransformResources"]["InstanceCount"] - elif job == "Processing": - instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"] - elif job == "AutoML": - instance_count = 0 - - stream_names = [] # The list of log streams - positions = {} # The current position in each stream, map of stream name -> position - - # Increase retries allowed (from default of 4), as we don't want waiting for a training job - # to be interrupted by a transient exception. - config = botocore.config.Config(retries={"max_attempts": 15}) - client = boto_session.client("logs", config=config) - log_group = "/aws/sagemaker/" + job + "Jobs" - - dot = False - - color_wrap = sagemaker.core.logs.ColorWrap() - - return instance_count, stream_names, positions, client, log_group, dot, color_wrap - - -def _check_job_status(job, desc, status_key_name): - """Check to see if the job completed successfully. - - If not, construct and raise a exceptions. (UnexpectedStatusException). - - Args: - job (str): The name of the job to check. - desc (dict[str, str]): The result of ``describe_training_job()``. - status_key_name (str): Status key name to check for. - - Raises: - exceptions.CapacityError: If the training job fails with CapacityError. - exceptions.UnexpectedStatusException: If the training job fails. - """ - status = desc[status_key_name] - # If the status is capital case, then convert it to Camel case - status = _STATUS_CODE_TABLE.get(status, status) - - if status == "Stopped": - logger.warning( - "Job ended with status 'Stopped' rather than 'Completed'. " - "This could mean the job timed out or stopped early for some other reason: " - "Consider checking whether it completed as you expect." - ) - elif status != "Completed": - reason = desc.get("FailureReason", "(No reason provided)") - job_type = status_key_name.replace("JobStatus", " job") - troubleshooting = ( - "https://docs.aws.amazon.com/sagemaker/latest/dg/" - "sagemaker-python-sdk-troubleshooting.html" - ) - message = ( - "Error for {job_type} {job_name}: {status}. Reason: {reason}. " - "Check troubleshooting guide for common errors: {troubleshooting}" - ).format( - job_type=job_type, - job_name=job, - status=status, - reason=reason, - troubleshooting=troubleshooting, - ) - if "CapacityError" in str(reason): - raise exceptions.CapacityError( - message=message, - allowed_statuses=["Completed", "Stopped"], - actual_status=status, - ) - raise exceptions.UnexpectedStatusException( - message=message, - allowed_statuses=["Completed", "Stopped"], - actual_status=status, - ) - - -def _create_resource(create_fn): - """Call create function and accepts/pass when resource already exists. - - This is a helper function to use an existing resource if found when creating. - - Args: - create_fn: Create resource function. - - Returns: - (bool): True if new resource was created, False if resource already exists. - """ - try: - create_fn() - # create function succeeded, resource does not exist already - return True - except ClientError as ce: - error_code = ce.response["Error"]["Code"] - error_message = ce.response["Error"]["Message"] - already_exists_exceptions = ["ValidationException", "ResourceInUse"] - already_exists_msg_patterns = ["Cannot create already existing", "already exists"] - if not ( - error_code in already_exists_exceptions - and any(p in error_message for p in already_exists_msg_patterns) - ): - raise ce - # no new resource created as resource already exists - return False - - -def _is_s3_uri(s3_uri: Optional[str]) -> bool: - """Checks whether an S3 URI is valid. - - Args: - s3_uri (Optional[str]): The S3 URI. - - Returns: - bool: Whether the S3 URI is valid. - """ - if s3_uri is None: - return False + # if there is nothing to set, return early. And there is no need t - return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None +... [FILE TRUNCATED: This file is 85116 characters (1063 lines) which exceeds the 50000 character limit. Only the first 50000 characters are shown. Consider using search_code to find specific functions or sections instead of reading the entire file.] diff --git a/sagemaker-core/src/sagemaker/core/utils/__init__.py b/sagemaker-core/src/sagemaker/core/utils/__init__.py index 9947387537..7ac8d6fe1f 100644 --- a/sagemaker-core/src/sagemaker/core/utils/__init__.py +++ b/sagemaker-core/src/sagemaker/core/utils/__init__.py @@ -36,6 +36,8 @@ "sagemaker_timestamp", "sagemaker_short_timestamp", "get_config_value", + "_validate_source_directory", + "_validate_dependency_path", ] diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py index 15540fcd1f..f6ca9aaa29 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py @@ -140,7 +140,11 @@ def repack(inference_script, model_archive, source_dir=None): # pragma: no cove # the data directory contains a model archive generated by a previous training job # With ModelTrainer, the input data is available in the training channel data_directory = "/opt/ml/input/data/training" - model_path = os.path.join(data_directory, model_archive.split("/")[-1]) + archive_basename = os.path.basename(model_archive) + # Validate the basename to prevent path traversal + if not archive_basename or archive_basename in (".", "..") or os.sep in archive_basename: + raise ValueError(f"Invalid model archive name: {model_archive}") + model_path = os.path.join(data_directory, archive_basename) # create a temporary directory with tempfile.TemporaryDirectory() as tmp: @@ -156,6 +160,10 @@ def repack(inference_script, model_archive, source_dir=None): # pragma: no cove with tarfile.open(name=local_path, mode="r:gz") as tf: custom_extractall_tarfile(tf, src_dir) + # Validate inference_script does not contain path traversal components + if os.sep in inference_script or inference_script.startswith(".."): + raise ValueError(f"Invalid inference script path: {inference_script}") + if source_dir: # copy /opt/ml/input/data/code to code/ (ModelTrainer structure) source_code_path = "/opt/ml/input/data/code" @@ -181,11 +189,25 @@ def repack(inference_script, model_archive, source_dir=None): # pragma: no cove entry_point = path break if entry_point: - shutil.copy2(entry_point, os.path.join(code_dir, inference_script)) + dest_path = os.path.join(code_dir, inference_script) + # Verify destination is within code_dir + resolved_dest = _get_resolved_path(dest_path) + resolved_code = _get_resolved_path(code_dir) + if not resolved_dest.startswith(resolved_code + os.sep) and resolved_dest != resolved_code: + raise ValueError( + f"Inference script destination {dest_path} escapes code directory" + ) + shutil.copy2(entry_point, dest_path) # Note: Requirements.txt dependencies are automatically installed by ModelTrainer # before this script runs, so no additional installation is needed here. + # Verify the code directory exists and is within src_dir before copying + resolved_src = _get_resolved_path(src_dir) + resolved_code = _get_resolved_path(code_dir) + if not resolved_code.startswith(resolved_src + os.sep): + raise ValueError("Code directory is not within the source directory") + # copy the "src" dir, which includes the previous training job's model and the # custom inference script, to the output of this training job shutil.copytree(src_dir, "/opt/ml/model", dirs_exist_ok=True) diff --git a/sagemaker-mlops/tests/unit/workflow/test_repack_model.py b/sagemaker-mlops/tests/unit/workflow/test_repack_model.py index 24936594be..885e4e0865 100644 --- a/sagemaker-mlops/tests/unit/workflow/test_repack_model.py +++ b/sagemaker-mlops/tests/unit/workflow/test_repack_model.py @@ -217,3 +217,28 @@ def test_custom_extractall_tarfile_without_data_filter(): call_args = mock_tar.extractall.call_args assert call_args[1]['path'] == extract_path assert 'members' in call_args[1] + + +def test_is_bad_path_with_dotdot_traversal(): + """Test _is_bad_path correctly identifies directory traversal with '..'.""" + base = _get_resolved_path("/tmp/safe_base") + # A path that tries to escape via .. + assert _is_bad_path("../../etc/passwd", base) is True + + +def test_is_bad_path_with_absolute_escape(): + """Test _is_bad_path correctly identifies absolute path escape.""" + base = _get_resolved_path("/tmp/safe_base") + assert _is_bad_path("/etc/passwd", base) is True + + +def test_is_bad_path_nested_safe(): + """Test _is_bad_path allows nested safe paths.""" + base = _get_resolved_path("/tmp/safe_base") + assert _is_bad_path("subdir/file.txt", base) is False + + +def test_get_safe_members_empty_list(): + """Test _get_safe_members with empty member list.""" + safe_members = list(_get_safe_members([], "/tmp/extract")) + assert len(safe_members) == 0