Skip to content

Commit 19a3c1d

Browse files
committed
fix: Add additional dependencies for ModelTrainer (#5668)
1 parent 272fdbf commit 19a3c1d

File tree

5 files changed

+345
-0
lines changed

5 files changed

+345
-0
lines changed

sagemaker-core/src/sagemaker/core/training/configs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ class SourceCode(BaseConfig):
109109
ignore_patterns: (Optional[List[str]]) :
110110
The ignore patterns to ignore specific files/folders when uploading to S3. If not specified,
111111
defaults to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
112+
dependencies (Optional[List[str]]):
113+
A list of paths to local directories or files (absolute or relative) containing additional
114+
libraries that will be copied into the training container and added to PYTHONPATH.
115+
Each path must be a valid local directory or file.
112116
"""
113117

114118
source_dir: Optional[StrPipeVar] = None
@@ -123,6 +127,7 @@ class SourceCode(BaseConfig):
123127
".cache",
124128
".ipynb_checkpoints",
125129
]
130+
dependencies: Optional[List[str]] = None
126131

127132
class OutputDataConfig(shapes.OutputDataConfig):
128133
"""OutputDataConfig.

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@
6363
"amazon.nova-pro-v1:0": ["us-east-1"]
6464
}
6565

66+
SM_DEPENDENCIES = "sm_dependencies"
67+
SM_DEPENDENCIES_CONTAINER_PATH = "/opt/ml/input/data/sm_dependencies"
68+
6669
SM_RECIPE = "recipe"
6770
SM_RECIPE_YAML = "recipe.yaml"
6871
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@
8383
SM_CODE_CONTAINER_PATH,
8484
SM_DRIVERS,
8585
SM_DRIVERS_LOCAL_PATH,
86+
SM_DEPENDENCIES,
87+
SM_DEPENDENCIES_CONTAINER_PATH,
8688
SM_RECIPE,
8789
SM_RECIPE_YAML,
8890
SM_RECIPE_CONTAINER_PATH,
@@ -99,6 +101,7 @@
99101
EXECUTE_BASIC_SCRIPT_DRIVER,
100102
INSTALL_AUTO_REQUIREMENTS,
101103
INSTALL_REQUIREMENTS,
104+
INSTALL_DEPENDENCIES,
102105
)
103106
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
104107
from sagemaker.core.telemetry.constants import Feature
@@ -484,6 +487,13 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
484487
f"Invalid 'entry_script': {entry_script}. "
485488
"Must be a valid file within the 'source_dir'.",
486489
)
490+
if source_code.dependencies:
491+
for dep_path in source_code.dependencies:
492+
if not _is_valid_path(dep_path):
493+
raise ValueError(
494+
f"Invalid dependency path: {dep_path}. "
495+
"Each dependency must be a valid local directory or file path."
496+
)
487497

488498
@staticmethod
489499
def _validate_and_fetch_hyperparameters_file(hyperparameters_file: str):
@@ -654,6 +664,24 @@ def _create_training_job_args(
654664
)
655665
final_input_data_config.append(source_code_channel)
656666

667+
# If dependencies are provided, create a channel for the dependencies
668+
# The dependencies will be mounted at /opt/ml/input/data/sm_dependencies
669+
if self.source_code.dependencies:
670+
deps_tmp_dir = TemporaryDirectory()
671+
for dep_path in self.source_code.dependencies:
672+
dep_basename = os.path.basename(os.path.normpath(dep_path))
673+
dest_path = os.path.join(deps_tmp_dir.name, dep_basename)
674+
if os.path.isdir(dep_path):
675+
shutil.copytree(dep_path, dest_path, dirs_exist_ok=True)
676+
else:
677+
shutil.copy2(dep_path, dest_path)
678+
dependencies_channel = self.create_input_data_channel(
679+
channel_name=SM_DEPENDENCIES,
680+
data_source=deps_tmp_dir.name,
681+
key_prefix=input_data_key_prefix,
682+
)
683+
final_input_data_config.append(dependencies_channel)
684+
657685
self._prepare_train_script(
658686
tmp_dir=self._temp_code_dir,
659687
source_code=self.source_code,
@@ -1010,6 +1038,10 @@ def _prepare_train_script(
10101038
base_command = source_code.command.split()
10111039
base_command = " ".join(base_command)
10121040

1041+
install_dependencies = ""
1042+
if source_code.dependencies:
1043+
install_dependencies = INSTALL_DEPENDENCIES
1044+
10131045
install_requirements = ""
10141046
if source_code.requirements:
10151047
if self._jumpstart_config and source_code.requirements == "auto":
@@ -1049,6 +1081,7 @@ def _prepare_train_script(
10491081

10501082
train_script = TRAIN_SCRIPT_TEMPLATE.format(
10511083
working_dir=working_dir,
1084+
install_dependencies=install_dependencies,
10521085
install_requirements=install_requirements,
10531086
execute_driver=execute_driver,
10541087
)

sagemaker-train/src/sagemaker/train/templates.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@
3939
$SM_PIP_CMD install -r {requirements_file}
4040
"""
4141

42+
INSTALL_DEPENDENCIES = """
43+
echo "Setting up additional dependencies"
44+
if [ -d /opt/ml/input/data/sm_dependencies ]; then
45+
for dep_dir in /opt/ml/input/data/sm_dependencies/*/; do
46+
if [ -d "$dep_dir" ]; then
47+
echo "Adding $dep_dir to PYTHONPATH"
48+
export PYTHONPATH="$dep_dir:$PYTHONPATH"
49+
fi
50+
done
51+
# Also add the root dependencies dir in case of single files
52+
export PYTHONPATH="/opt/ml/input/data/sm_dependencies:$PYTHONPATH"
53+
fi
54+
"""
55+
4256
EXEUCTE_DISTRIBUTED_DRIVER = """
4357
echo "Running {driver_name} Driver"
4458
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script}
@@ -95,6 +109,7 @@
95109
set -x
96110
97111
{working_dir}
112+
{install_dependencies}
98113
{install_requirements}
99114
{execute_driver}
100115

0 commit comments

Comments
 (0)