Skip to content

Commit 4e5fda9

Browse files
committed
Use network flag
1 parent c0134f5 commit 4e5fda9

3 files changed

Lines changed: 39 additions & 22 deletions

File tree

sagemaker-core/src/sagemaker/core/modules/local_core/local_container.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,13 @@
5757
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"
5858

5959

60-
def _rmtree(path, image=None):
60+
def _rmtree(path, image=None, is_studio=False):
6161
"""Remove a directory tree, handling root-owned files from Docker containers."""
6262
try:
6363
shutil.rmtree(path)
6464
except PermissionError:
6565
# Files created by Docker containers are owned by root.
6666
# Use a Docker container to remove them since os.chmod will also fail.
67-
# Use the training image (already pulled) to avoid needing alpine.
6867
if image is None:
6968
logger.warning(
7069
"Failed to clean up root-owned files in %s. "
@@ -73,11 +72,11 @@ def _rmtree(path, image=None):
7372
)
7473
raise
7574
try:
76-
subprocess.run(
77-
["docker", "run", "--rm", "-v", f"{path}:/delete", image, "rm", "-rf", "/delete"],
78-
check=True,
79-
capture_output=True,
80-
)
75+
cmd = ["docker", "run", "--rm"]
76+
if is_studio:
77+
cmd += ["--network", "sagemaker"]
78+
cmd += ["-v", f"{path}:/delete", image, "rm", "-rf", "/delete"]
79+
subprocess.run(cmd, check=True, capture_output=True)
8180
# The mount point directory itself may remain — clean it up
8281
if os.path.exists(path):
8382
shutil.rmtree(path, ignore_errors=True)
@@ -242,12 +241,12 @@ def train(
242241
# Print our Job Complete line
243242
logger.info("Local training job completed, output artifacts saved to %s", artifacts)
244243

245-
_rmtree(os.path.join(self.container_root, "input"), self.image)
246-
_rmtree(os.path.join(self.container_root, "shared"), self.image)
244+
_rmtree(os.path.join(self.container_root, "input"), self.image, self.is_studio)
245+
_rmtree(os.path.join(self.container_root, "shared"), self.image, self.is_studio)
247246
for host in self.hosts:
248-
_rmtree(os.path.join(self.container_root, host), self.image)
247+
_rmtree(os.path.join(self.container_root, host), self.image, self.is_studio)
249248
for folder in self._temporary_folders:
250-
_rmtree(os.path.join(self.container_root, folder), self.image)
249+
_rmtree(os.path.join(self.container_root, folder), self.image, self.is_studio)
251250
return artifacts
252251

253252
def retrieve_artifacts(

sagemaker-train/src/sagemaker/train/local/local_container.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,13 @@
6565
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"
6666

6767

68-
def _rmtree(path, image=None):
68+
def _rmtree(path, image=None, is_studio=False):
6969
"""Remove a directory tree, handling root-owned files from Docker containers."""
7070
try:
7171
shutil.rmtree(path)
7272
except PermissionError:
7373
# Files created by Docker containers are owned by root.
7474
# Use a Docker container to remove them since os.chmod will also fail.
75-
# Use the training image (already pulled) to avoid needing alpine.
7675
if image is None:
7776
logger.warning(
7877
"Failed to clean up root-owned files in %s. "
@@ -81,11 +80,11 @@ def _rmtree(path, image=None):
8180
)
8281
raise
8382
try:
84-
subprocess.run(
85-
["docker", "run", "--rm", "-v", f"{path}:/delete", image, "rm", "-rf", "/delete"],
86-
check=True,
87-
capture_output=True,
88-
)
83+
cmd = ["docker", "run", "--rm"]
84+
if is_studio:
85+
cmd += ["--network", "sagemaker"]
86+
cmd += ["-v", f"{path}:/delete", image, "rm", "-rf", "/delete"]
87+
subprocess.run(cmd, check=True, capture_output=True)
8988
# The mount point directory itself may remain — clean it up
9089
if os.path.exists(path):
9190
shutil.rmtree(path, ignore_errors=True)
@@ -250,12 +249,12 @@ def train(
250249
# Print our Job Complete line
251250
logger.info("Local training job completed, output artifacts saved to %s", artifacts)
252251

253-
_rmtree(os.path.join(self.container_root, "input"), self.image)
254-
_rmtree(os.path.join(self.container_root, "shared"), self.image)
252+
_rmtree(os.path.join(self.container_root, "input"), self.image, self.is_studio)
253+
_rmtree(os.path.join(self.container_root, "shared"), self.image, self.is_studio)
255254
for host in self.hosts:
256-
_rmtree(os.path.join(self.container_root, host), self.image)
255+
_rmtree(os.path.join(self.container_root, host), self.image, self.is_studio)
257256
for folder in self._temporary_folders:
258-
_rmtree(os.path.join(self.container_root, folder), self.image)
257+
_rmtree(os.path.join(self.container_root, folder), self.image, self.is_studio)
259258
return artifacts
260259

261260
def retrieve_artifacts(

sagemaker-train/tests/unit/train/local/test_local_container.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,25 @@ def test_rmtree_permission_error_docker_fallback(self, mock_exists, mock_run, mo
4242
capture_output=True,
4343
)
4444

45+
@patch("sagemaker.train.local.local_container.shutil.rmtree")
46+
@patch("sagemaker.train.local.local_container.subprocess.run")
47+
@patch("sagemaker.train.local.local_container.os.path.exists", return_value=False)
48+
def test_rmtree_studio_adds_network(self, mock_exists, mock_run, mock_rmtree):
49+
"""In Studio, docker run includes --network sagemaker."""
50+
mock_rmtree.side_effect = PermissionError("Permission denied")
51+
52+
_rmtree("/tmp/test", IMAGE, is_studio=True)
53+
54+
mock_run.assert_called_once_with(
55+
[
56+
"docker", "run", "--rm",
57+
"--network", "sagemaker",
58+
"-v", "/tmp/test:/delete", IMAGE, "rm", "-rf", "/delete",
59+
],
60+
check=True,
61+
capture_output=True,
62+
)
63+
4564
@patch("sagemaker.train.local.local_container.shutil.rmtree")
4665
@patch("sagemaker.train.local.local_container.subprocess.run")
4766
@patch("sagemaker.train.local.local_container.os.path.exists", return_value=True)

0 commit comments

Comments
 (0)