7676WAITING_DOT_NUMBER = 10
7777MAX_ITEMS = 100
7878PAGE_SIZE = 10
79+ _MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
80+
81+ _SENSITIVE_SYSTEM_PATHS = [
82+ abspath (os .path .expanduser ("~/.aws" )),
83+ abspath (os .path .expanduser ("~/.ssh" )),
84+ abspath (os .path .expanduser ("~/.kube" )),
85+ abspath (os .path .expanduser ("~/.docker" )),
86+ abspath (os .path .expanduser ("~/.config" )),
87+ abspath (os .path .expanduser ("~/.credentials" )),
88+ abspath (realpath ("/etc" )),
89+ abspath (realpath ("/root" )),
90+ abspath (realpath ("/var/lib" )),
91+ abspath (realpath ("/opt/ml/metadata" )),
92+ ]
7993
8094logger = logging .getLogger (__name__ )
8195
@@ -601,11 +615,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
601615 shutil .move (tmp_model_path , repacked_model_uri .replace ("file://" , "" ))
602616
603617
618+ def _validate_source_directory (source_directory ):
619+ """Validate that source_directory is safe to use.
620+
621+ Ensures the source directory path does not access restricted system locations.
622+
623+ Args:
624+ source_directory (str): The source directory path to validate.
625+
626+ Raises:
627+ ValueError: If the path is not allowed.
628+ """
629+ if not source_directory or source_directory .lower ().startswith ("s3://" ):
630+ # S3 paths and None are safe
631+ return
632+
633+ # Resolve symlinks to get the actual path
634+ abs_source = abspath (realpath (source_directory ))
635+
636+ # Check if the source path is under any sensitive directory
637+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS :
638+ if abs_source != "/" and abs_source .startswith (sensitive_path ):
639+ raise ValueError (
640+ f"source_directory cannot access sensitive system paths. "
641+ f"Got: { source_directory } (resolved to { abs_source } )"
642+ )
643+
644+
645+ def _validate_dependency_path (dependency ):
646+ """Validate that a dependency path is safe to use.
647+
648+ Ensures the dependency path does not access restricted system locations.
649+
650+ Args:
651+ dependency (str): The dependency path to validate.
652+
653+ Raises:
654+ ValueError: If the path is not allowed.
655+ """
656+ if not dependency :
657+ return
658+
659+ # Resolve symlinks to get the actual path
660+ abs_dependency = abspath (realpath (dependency ))
661+
662+ # Check if the dependency path is under any sensitive directory
663+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS :
664+ if abs_dependency != "/" and abs_dependency .startswith (sensitive_path ):
665+ raise ValueError (
666+ f"dependency path cannot access sensitive system paths. "
667+ f"Got: { dependency } (resolved to { abs_dependency } )"
668+ )
669+
670+
604671def _create_or_update_code_dir (
605672 model_dir , inference_script , source_directory , dependencies , sagemaker_session , tmp
606673):
607674 """Placeholder docstring"""
608675 code_dir = os .path .join (model_dir , "code" )
676+ resolved_code_dir = _get_resolved_path (code_dir )
677+
678+ # Validate that code_dir does not resolve to a sensitive system path
679+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS :
680+ if resolved_code_dir != "/" and resolved_code_dir .startswith (sensitive_path ):
681+ raise ValueError (
682+ f"Invalid code_dir path: { code_dir } resolves to sensitive system path { resolved_code_dir } "
683+ )
684+
609685 if source_directory and source_directory .lower ().startswith ("s3://" ):
610686 local_code_path = os .path .join (tmp , "local_code.tar.gz" )
611687 download_file_from_url (source_directory , local_code_path , sagemaker_session )
@@ -614,6 +690,8 @@ def _create_or_update_code_dir(
614690 custom_extractall_tarfile (t , code_dir )
615691
616692 elif source_directory :
693+ # Validate source_directory for security
694+ _validate_source_directory (source_directory )
617695 if os .path .exists (code_dir ):
618696 shutil .rmtree (code_dir )
619697 shutil .copytree (source_directory , code_dir )
@@ -646,6 +724,8 @@ def _create_or_update_code_dir(
646724 )
647725
648726 for dependency in dependencies :
727+ # Validate dependency path for security
728+ _validate_dependency_path (dependency )
649729 lib_dir = os .path .join (code_dir , "lib" )
650730 if os .path .isdir (dependency ):
651731 shutil .copytree (dependency , os .path .join (lib_dir , os .path .basename (dependency )))
@@ -1620,6 +1700,38 @@ def _get_safe_members(members):
16201700 yield file_info
16211701
16221702
1703+ def _validate_extracted_paths (extract_path ):
1704+ """Validate that extracted paths remain within the expected directory.
1705+
1706+ Performs post-extraction validation to ensure all extracted files and directories
1707+ are within the intended extraction path.
1708+
1709+ Args:
1710+ extract_path (str): The path where files were extracted.
1711+
1712+ Raises:
1713+ ValueError: If any extracted file is outside the expected extraction path.
1714+ """
1715+ base = _get_resolved_path (extract_path )
1716+
1717+ for root , dirs , files in os .walk (extract_path ):
1718+ # Check directories
1719+ for dir_name in dirs :
1720+ dir_path = os .path .join (root , dir_name )
1721+ resolved = _get_resolved_path (dir_path )
1722+ if not resolved .startswith (base ):
1723+ logger .error ("Extracted directory escaped extraction path: %s" , dir_path )
1724+ raise ValueError (f"Extracted path outside expected directory: { dir_path } " )
1725+
1726+ # Check files
1727+ for file_name in files :
1728+ file_path = os .path .join (root , file_name )
1729+ resolved = _get_resolved_path (file_path )
1730+ if not resolved .startswith (base ):
1731+ logger .error ("Extracted file escaped extraction path: %s" , file_path )
1732+ raise ValueError (f"Extracted path outside expected directory: { file_path } " )
1733+
1734+
16231735def custom_extractall_tarfile (tar , extract_path ):
16241736 """Extract a tarfile, optionally using data_filter if available.
16251737
@@ -1640,6 +1752,8 @@ def custom_extractall_tarfile(tar, extract_path):
16401752 tar .extractall (path = extract_path , filter = "data" )
16411753 else :
16421754 tar .extractall (path = extract_path , members = _get_safe_members (tar ))
1755+ # Re-validate extracted paths to catch symlink race conditions
1756+ _validate_extracted_paths (extract_path )
16431757
16441758
16451759def can_model_package_source_uri_autopopulate (source_uri : str ):
0 commit comments