Skip to content

Commit 4deb5ef

Browse files
committed
fix: address review comments (iteration #1)
1 parent 19a3c1d commit 4deb5ef

File tree

5 files changed

+91
-28
lines changed

5 files changed

+91
-28
lines changed

sagemaker-core/src/sagemaker/core/training/configs.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,7 @@ class SourceCode(BaseConfig):
129129
]
130130
dependencies: Optional[List[str]] = None
131131

132+
132133
class OutputDataConfig(shapes.OutputDataConfig):
133134
"""OutputDataConfig.
134135

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,4 +68,4 @@
6868

6969
SM_RECIPE = "recipe"
7070
SM_RECIPE_YAML = "recipe.yaml"
71-
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"
71+
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -272,6 +272,7 @@ class ModelTrainer(BaseModel):
272272

273273
# Private Attributes for AWS_Batch
274274
_temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
275+
_temp_deps_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
275276

276277
CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [
277278
"role",
@@ -411,6 +412,8 @@ def __del__(self):
411412
self._temp_recipe_train_dir.cleanup()
412413
if self._temp_code_dir is not None:
413414
self._temp_code_dir.cleanup()
415+
if self._temp_deps_dir is not None:
416+
self._temp_deps_dir.cleanup()
414417

415418
def _validate_training_image_and_algorithm_name(
416419
self, training_image: Optional[str], algorithm_name: Optional[str]
@@ -489,10 +492,10 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
489492
)
490493
if source_code.dependencies:
491494
for dep_path in source_code.dependencies:
492-
if not _is_valid_path(dep_path):
495+
if not _is_valid_path(dep_path, path_type="Directory"):
493496
raise ValueError(
494497
f"Invalid dependency path: {dep_path}. "
495-
"Each dependency must be a valid local directory or file path."
498+
"Each dependency must be a valid local directory path."
496499
)
497500

498501
@staticmethod
@@ -667,17 +670,14 @@ def _create_training_job_args(
667670
# If dependencies are provided, create a channel for the dependencies
668671
# The dependencies will be mounted at /opt/ml/input/data/sm_dependencies
669672
if self.source_code.dependencies:
670-
deps_tmp_dir = TemporaryDirectory()
673+
self._temp_deps_dir = TemporaryDirectory()
671674
for dep_path in self.source_code.dependencies:
672675
dep_basename = os.path.basename(os.path.normpath(dep_path))
673-
dest_path = os.path.join(deps_tmp_dir.name, dep_basename)
674-
if os.path.isdir(dep_path):
675-
shutil.copytree(dep_path, dest_path, dirs_exist_ok=True)
676-
else:
677-
shutil.copy2(dep_path, dest_path)
676+
dest_path = os.path.join(self._temp_deps_dir.name, dep_basename)
677+
shutil.copytree(dep_path, dest_path, dirs_exist_ok=True)
678678
dependencies_channel = self.create_input_data_channel(
679679
channel_name=SM_DEPENDENCIES,
680-
data_source=deps_tmp_dir.name,
680+
data_source=self._temp_deps_dir.name,
681681
key_prefix=input_data_key_prefix,
682682
)
683683
final_input_data_config.append(dependencies_channel)
@@ -841,6 +841,9 @@ def train(
841841
local_container.train(wait)
842842
if self._temp_code_dir is not None:
843843
self._temp_code_dir.cleanup()
844+
if self._temp_deps_dir is not None:
845+
self._temp_deps_dir.cleanup()
846+
self._temp_deps_dir = None
844847

845848

846849
def create_input_data_channel(

sagemaker-train/src/sagemaker/train/templates.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -42,14 +42,23 @@
4242
INSTALL_DEPENDENCIES = """
4343
echo "Setting up additional dependencies"
4444
if [ -d /opt/ml/input/data/sm_dependencies ]; then
45-
for dep_dir in /opt/ml/input/data/sm_dependencies/*/; do
46-
if [ -d "$dep_dir" ]; then
47-
echo "Adding $dep_dir to PYTHONPATH"
48-
export PYTHONPATH="$dep_dir:$PYTHONPATH"
45+
for dep in /opt/ml/input/data/sm_dependencies/*; do
46+
if [ -d "$dep" ]; then
47+
echo "Adding directory $dep to PYTHONPATH"
48+
export PYTHONPATH="$dep:$PYTHONPATH"
49+
elif [ -f "$dep" ]; then
50+
case "$dep" in
51+
*.whl|*.tar.gz)
52+
echo "Installing package $dep via pip"
53+
$SM_PIP_CMD install "$dep"
54+
;;
55+
*)
56+
echo "Adding parent directory of $dep to PYTHONPATH"
57+
export PYTHONPATH="/opt/ml/input/data/sm_dependencies:$PYTHONPATH"
58+
;;
59+
esac
4960
fi
5061
done
51-
# Also add the root dependencies dir in case of single files
52-
export PYTHONPATH="/opt/ml/input/data/sm_dependencies:$PYTHONPATH"
5362
fi
5463
"""
5564

sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py

Lines changed: 62 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Tests for ModelTrainer dependencies feature."""
14-
from __future__ import absolute_import
14+
from __future__ import annotations
1515

1616
import os
1717
import shutil
@@ -53,7 +53,7 @@
5353
)
5454

5555

56-
@pytest.fixture(scope="module", autouse=True)
56+
@pytest.fixture(autouse=True)
5757
def modules_session():
5858
with patch("sagemaker.train.Session", spec=Session) as session_mock:
5959
session_instance = session_mock.return_value
@@ -107,6 +107,27 @@ def test_validate_source_code_with_invalid_dependency_path_raises_value_error(mo
107107
)
108108

109109

110+
def test_validate_source_code_with_file_dependency_raises_value_error(modules_session):
111+
"""Verify _validate_source_code raises ValueError when a dependency is a file, not a directory."""
112+
dep_file = tempfile.NamedTemporaryFile(suffix=".py", delete=False)
113+
dep_file.close()
114+
try:
115+
with pytest.raises(ValueError, match="Invalid dependency path"):
116+
ModelTrainer(
117+
training_image=DEFAULT_IMAGE,
118+
role=DEFAULT_ROLE,
119+
sagemaker_session=modules_session,
120+
compute=DEFAULT_COMPUTE_CONFIG,
121+
source_code=SourceCode(
122+
source_dir=DEFAULT_SOURCE_DIR,
123+
entry_script="custom_script.py",
124+
dependencies=[dep_file.name],
125+
),
126+
)
127+
finally:
128+
os.unlink(dep_file.name)
129+
130+
110131
def test_source_code_dependencies_validation_with_valid_dirs(modules_session):
111132
"""Create SourceCode with dependencies pointing to multiple valid directories."""
112133
dep_dir1 = tempfile.mkdtemp()
@@ -250,23 +271,39 @@ def test_train_with_dependencies_generates_pythonpath_setup_in_train_script(
250271
dependencies=[dep_dir],
251272
),
252273
)
253-
trainer.train()
254-
255-
# Check the generated train script in the temp directory
256-
assert trainer._temp_code_dir is not None or True # temp dir may be cleaned up
257-
# The key assertion is that the template was used - check via the training job call
258-
mock_training_job.create.assert_called_once()
274+
# Call _create_training_job_args to generate the train script without cleanup
275+
trainer._create_training_job_args()
276+
277+
# Read the generated train script and verify it contains PYTHONPATH setup
278+
assert trainer._temp_code_dir is not None
279+
train_script_path = os.path.join(trainer._temp_code_dir.name, "sm_train.sh")
280+
with open(train_script_path) as f:
281+
script_content = f.read()
282+
assert "sm_dependencies" in script_content
283+
assert "PYTHONPATH" in script_content
284+
assert "Setting up additional dependencies" in script_content
259285
finally:
286+
if trainer._temp_code_dir is not None:
287+
trainer._temp_code_dir.cleanup()
288+
if trainer._temp_deps_dir is not None:
289+
trainer._temp_deps_dir.cleanup()
260290
shutil.rmtree(dep_dir)
261291

262292

263-
def test_dependencies_copied_to_temp_dir_preserving_basenames(modules_session):
293+
@patch("sagemaker.train.model_trainer.TrainingJob")
294+
def test_dependencies_copied_to_temp_dir_preserving_basenames(
295+
mock_training_job, modules_session
296+
):
264297
"""Verify that each dependency directory's basename is preserved when copied."""
265298
dep_dir = tempfile.mkdtemp(suffix="_mylib")
266299
sub_file = os.path.join(dep_dir, "module.py")
267300
with open(sub_file, "w") as f:
268301
f.write("# test module")
269302

303+
modules_session.upload_data.return_value = (
304+
f"s3://{DEFAULT_BUCKET}/prefix/sm_dependencies"
305+
)
306+
270307
try:
271308
trainer = ModelTrainer(
272309
training_image=DEFAULT_IMAGE,
@@ -281,9 +318,22 @@ def test_dependencies_copied_to_temp_dir_preserving_basenames(modules_session):
281318
dependencies=[dep_dir],
282319
),
283320
)
284-
# Verify the source_code has the dependencies set
285-
assert trainer.source_code.dependencies == [dep_dir]
321+
# Call _create_training_job_args to trigger the copy
322+
trainer._create_training_job_args()
323+
324+
# Verify the dependencies were copied preserving basenames
325+
assert trainer._temp_deps_dir is not None
286326
dep_basename = os.path.basename(os.path.normpath(dep_dir))
287-
assert dep_basename.endswith("_mylib")
327+
copied_dep_path = os.path.join(trainer._temp_deps_dir.name, dep_basename)
328+
assert os.path.isdir(copied_dep_path), (
329+
f"Expected dependency directory {copied_dep_path} to exist"
330+
)
331+
assert os.path.isfile(os.path.join(copied_dep_path, "module.py")), (
332+
"Expected module.py to be copied into the dependency directory"
333+
)
288334
finally:
335+
if trainer._temp_code_dir is not None:
336+
trainer._temp_code_dir.cleanup()
337+
if trainer._temp_deps_dir is not None:
338+
trainer._temp_deps_dir.cleanup()
289339
shutil.rmtree(dep_dir)

0 commit comments

Comments
 (0)