|
83 | 83 | SM_CODE_CONTAINER_PATH, |
84 | 84 | SM_DRIVERS, |
85 | 85 | SM_DRIVERS_LOCAL_PATH, |
| 86 | + SM_DEPENDENCIES, |
| 87 | + SM_DEPENDENCIES_CONTAINER_PATH, |
86 | 88 | SM_RECIPE, |
87 | 89 | SM_RECIPE_YAML, |
88 | 90 | SM_RECIPE_CONTAINER_PATH, |
|
99 | 101 | EXECUTE_BASIC_SCRIPT_DRIVER, |
100 | 102 | INSTALL_AUTO_REQUIREMENTS, |
101 | 103 | INSTALL_REQUIREMENTS, |
| 104 | + INSTALL_DEPENDENCIES, |
102 | 105 | ) |
103 | 106 | from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter |
104 | 107 | from sagemaker.core.telemetry.constants import Feature |
@@ -484,6 +487,13 @@ def _validate_source_code(self, source_code: Optional[SourceCode]): |
484 | 487 | f"Invalid 'entry_script': {entry_script}. " |
485 | 488 | "Must be a valid file within the 'source_dir'.", |
486 | 489 | ) |
| 490 | + if source_code.dependencies: |
| 491 | + for dep_path in source_code.dependencies: |
| 492 | + if not _is_valid_path(dep_path): |
| 493 | + raise ValueError( |
| 494 | + f"Invalid dependency path: {dep_path}. " |
| 495 | + "Each dependency must be a valid local directory or file path." |
| 496 | + ) |
487 | 497 |
|
488 | 498 | @staticmethod |
489 | 499 | def _validate_and_fetch_hyperparameters_file(hyperparameters_file: str): |
@@ -654,6 +664,24 @@ def _create_training_job_args( |
654 | 664 | ) |
655 | 665 | final_input_data_config.append(source_code_channel) |
656 | 666 |
|
| 667 | + # If dependencies are provided, create a channel for the dependencies |
| 668 | + # The dependencies will be mounted at /opt/ml/input/data/sm_dependencies |
| 669 | + if self.source_code.dependencies: |
| 670 | + deps_tmp_dir = TemporaryDirectory() |
| 671 | + for dep_path in self.source_code.dependencies: |
| 672 | + dep_basename = os.path.basename(os.path.normpath(dep_path)) |
| 673 | + dest_path = os.path.join(deps_tmp_dir.name, dep_basename) |
| 674 | + if os.path.isdir(dep_path): |
| 675 | + shutil.copytree(dep_path, dest_path, dirs_exist_ok=True) |
| 676 | + else: |
| 677 | + shutil.copy2(dep_path, dest_path) |
| 678 | + dependencies_channel = self.create_input_data_channel( |
| 679 | + channel_name=SM_DEPENDENCIES, |
| 680 | + data_source=deps_tmp_dir.name, |
| 681 | + key_prefix=input_data_key_prefix, |
| 682 | + ) |
| 683 | + final_input_data_config.append(dependencies_channel) |
| 684 | + |
657 | 685 | self._prepare_train_script( |
658 | 686 | tmp_dir=self._temp_code_dir, |
659 | 687 | source_code=self.source_code, |
@@ -1010,6 +1038,10 @@ def _prepare_train_script( |
1010 | 1038 | base_command = source_code.command.split() |
1011 | 1039 | base_command = " ".join(base_command) |
1012 | 1040 |
|
| 1041 | + install_dependencies = "" |
| 1042 | + if source_code.dependencies: |
| 1043 | + install_dependencies = INSTALL_DEPENDENCIES |
| 1044 | + |
1013 | 1045 | install_requirements = "" |
1014 | 1046 | if source_code.requirements: |
1015 | 1047 | if self._jumpstart_config and source_code.requirements == "auto": |
@@ -1049,6 +1081,7 @@ def _prepare_train_script( |
1049 | 1081 |
|
1050 | 1082 | train_script = TRAIN_SCRIPT_TEMPLATE.format( |
1051 | 1083 | working_dir=working_dir, |
| 1084 | + install_dependencies=install_dependencies, |
1052 | 1085 | install_requirements=install_requirements, |
1053 | 1086 | execute_driver=execute_driver, |
1054 | 1087 | ) |
|
0 commit comments