Skip to content

Commit c1d8d1a

Browse files
committed
Bug fix 3 and 4
1 parent 78cff41 commit c1d8d1a

File tree

3 files changed

+202
-8
lines changed

3 files changed

+202
-8
lines changed

sagemaker-core/src/sagemaker/core/modules/local_core/local_container.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import re
1919
import shutil
20+
import stat
2021
import subprocess
2122
from tempfile import TemporaryDirectory
2223
from typing import Any, Dict, List, Optional
@@ -57,6 +58,17 @@
5758
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"
5859

5960

61+
def _rmtree(path):
    """Remove a directory tree, repairing restrictive permission bits on the way.

    Deletion is done with ``shutil.rmtree``. When a step fails with
    ``PermissionError`` the owner permission bits of the failing entry (and of
    its parent directory, when that parent lies inside the tree) are restored
    and the step is retried. Every other error is re-raised unchanged.

    Entries whose mode the current user is not allowed to change (for example
    files owned by another user) still end in ``PermissionError``.

    Args:
        path: Directory to delete.

    Raises:
        PermissionError: If permissions could not be repaired or the retry failed.
        OSError: Any other error reported by ``shutil.rmtree``.
    """
    import sys  # local import: only needed to select the rmtree callback keyword

    root = os.path.abspath(path)
    retried = set()  # directories already handed to a recursive retry

    def _make_accessible(failed_path):
        """Give the owner rwx on ``failed_path`` and on its in-tree parent."""
        parent = os.path.dirname(failed_path)
        # Deleting an entry needs write+search on its parent directory, but
        # modes of directories outside the tree being removed are left alone.
        if parent and os.path.abspath(failed_path) != root:
            os.chmod(parent, stat.S_IMODE(os.stat(parent).st_mode) | stat.S_IRWXU)
        # os.chmod follows symlinks; skip links so their targets stay untouched.
        if not os.path.islink(failed_path):
            os.chmod(failed_path, stat.S_IRWXU)

    def _on_exc(func, failed_path, exc):
        """Retry a failed rmtree step after repairing permissions."""
        if isinstance(exc, FileNotFoundError) and failed_path in retried:
            # The recursive retry below already deleted this directory.
            return
        if not isinstance(exc, PermissionError):
            raise exc
        try:
            _make_accessible(failed_path)
        except OSError as chmod_error:
            raise exc from chmod_error
        if func in (os.unlink, os.rmdir):
            func(failed_path)
        elif os.path.isdir(failed_path) and not os.path.islink(failed_path):
            # Traversal helpers (os.open, os.scandir, os.lstat) cannot be
            # re-invoked as ``func(path)``; delete the subtree again instead.
            if failed_path in retried:
                raise exc  # a second failure: do not recurse forever
            retried.add(failed_path)
            _remove(failed_path)
        elif os.path.lexists(failed_path):
            os.unlink(failed_path)

    def _on_error(func, failed_path, exc_info):
        """Adapt the legacy ``onerror`` signature to ``_on_exc``."""
        _on_exc(func, failed_path, exc_info[1])

    def _remove(target):
        """Run ``shutil.rmtree`` with the callback keyword this Python supports."""
        if sys.version_info >= (3, 12):
            # ``onerror`` is deprecated since Python 3.12 in favour of ``onexc``.
            shutil.rmtree(target, onexc=_on_exc)
        else:
            shutil.rmtree(target, onerror=_on_error)

    _remove(path)
70+
71+
6072
class _LocalContainer(BaseModel):
6173
"""A local training job class for local mode model trainer.
6274
@@ -209,12 +221,12 @@ def train(
209221
# Print our Job Complete line
210222
logger.info("Local training job completed, output artifacts saved to %s", artifacts)
211223

212-
shutil.rmtree(os.path.join(self.container_root, "input"))
213-
shutil.rmtree(os.path.join(self.container_root, "shared"))
224+
_rmtree(os.path.join(self.container_root, "input"))
225+
_rmtree(os.path.join(self.container_root, "shared"))
214226
for host in self.hosts:
215-
shutil.rmtree(os.path.join(self.container_root, host))
227+
_rmtree(os.path.join(self.container_root, host))
216228
for folder in self._temporary_folders:
217-
shutil.rmtree(os.path.join(self.container_root, folder))
229+
_rmtree(os.path.join(self.container_root, folder))
218230
return artifacts
219231

220232
def retrieve_artifacts(

sagemaker-train/src/sagemaker/train/local/local_container.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
import re
2020
import shutil
21+
import stat
2122
import subprocess
2223
from tempfile import TemporaryDirectory
2324
from typing import Any, Dict, List, Optional
@@ -65,6 +66,17 @@
6566
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"
6667

6768

69+
def _rmtree(path):
    """Remove a directory tree, repairing restrictive permission bits on the way.

    Deletion is done with ``shutil.rmtree``. When a step fails with
    ``PermissionError`` the owner permission bits of the failing entry (and of
    its parent directory, when that parent lies inside the tree) are restored
    and the step is retried. Every other error is re-raised unchanged.

    Entries whose mode the current user is not allowed to change (for example
    files owned by another user) still end in ``PermissionError``.

    Args:
        path: Directory to delete.

    Raises:
        PermissionError: If permissions could not be repaired or the retry failed.
        OSError: Any other error reported by ``shutil.rmtree``.
    """
    import sys  # local import: only needed to select the rmtree callback keyword

    root = os.path.abspath(path)
    retried = set()  # directories already handed to a recursive retry

    def _make_accessible(failed_path):
        """Give the owner rwx on ``failed_path`` and on its in-tree parent."""
        parent = os.path.dirname(failed_path)
        # Deleting an entry needs write+search on its parent directory, but
        # modes of directories outside the tree being removed are left alone.
        if parent and os.path.abspath(failed_path) != root:
            os.chmod(parent, stat.S_IMODE(os.stat(parent).st_mode) | stat.S_IRWXU)
        # os.chmod follows symlinks; skip links so their targets stay untouched.
        if not os.path.islink(failed_path):
            os.chmod(failed_path, stat.S_IRWXU)

    def _on_exc(func, failed_path, exc):
        """Retry a failed rmtree step after repairing permissions."""
        if isinstance(exc, FileNotFoundError) and failed_path in retried:
            # The recursive retry below already deleted this directory.
            return
        if not isinstance(exc, PermissionError):
            raise exc
        try:
            _make_accessible(failed_path)
        except OSError as chmod_error:
            raise exc from chmod_error
        if func in (os.unlink, os.rmdir):
            func(failed_path)
        elif os.path.isdir(failed_path) and not os.path.islink(failed_path):
            # Traversal helpers (os.open, os.scandir, os.lstat) cannot be
            # re-invoked as ``func(path)``; delete the subtree again instead.
            if failed_path in retried:
                raise exc  # a second failure: do not recurse forever
            retried.add(failed_path)
            _remove(failed_path)
        elif os.path.lexists(failed_path):
            os.unlink(failed_path)

    def _on_error(func, failed_path, exc_info):
        """Adapt the legacy ``onerror`` signature to ``_on_exc``."""
        _on_exc(func, failed_path, exc_info[1])

    def _remove(target):
        """Run ``shutil.rmtree`` with the callback keyword this Python supports."""
        if sys.version_info >= (3, 12):
            # ``onerror`` is deprecated since Python 3.12 in favour of ``onexc``.
            shutil.rmtree(target, onexc=_on_exc)
        else:
            shutil.rmtree(target, onerror=_on_error)

    _remove(path)
78+
79+
6880
class _LocalContainer(BaseModel):
6981
"""A local training job class for local mode model trainer.
7082
@@ -217,12 +229,12 @@ def train(
217229
# Print our Job Complete line
218230
logger.info("Local training job completed, output artifacts saved to %s", artifacts)
219231

220-
shutil.rmtree(os.path.join(self.container_root, "input"))
221-
shutil.rmtree(os.path.join(self.container_root, "shared"))
232+
_rmtree(os.path.join(self.container_root, "input"))
233+
_rmtree(os.path.join(self.container_root, "shared"))
222234
for host in self.hosts:
223-
shutil.rmtree(os.path.join(self.container_root, host))
235+
_rmtree(os.path.join(self.container_root, host))
224236
for folder in self._temporary_folders:
225-
shutil.rmtree(os.path.join(self.container_root, folder))
237+
_rmtree(os.path.join(self.container_root, folder))
226238
return artifacts
227239

228240
def retrieve_artifacts(

sagemaker-train/tests/unit/train/test_defaults.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DEFAULT_MAX_RUNTIME_IN_SECONDS,
2626
)
2727
from sagemaker.train.configs import Compute, StoppingCondition
28+
from sagemaker.core.shapes import InstanceGroup
2829

2930

3031
class TestDefaultConstants:
@@ -435,3 +436,172 @@ def test_uses_default_volume_size_when_not_in_document(
435436
)
436437

437438
assert result.volume_size_in_gb == DEFAULT_VOLUME_SIZE
439+
440+
441+
def test_does_not_set_instance_type_when_instance_groups_configured(self):
    """instance_type must stay None when the Compute carries instance_groups."""
    group = InstanceGroup(
        instance_type="ml.p3.2xlarge",
        instance_count=1,
        instance_group_name="group1",
    )
    heterogeneous = Compute(
        instance_groups=[group],
        instance_type=None,
        instance_count=None,
        volume_size_in_gb=30,
    )

    resolved = TrainDefaults.get_compute(compute=heterogeneous)

    assert resolved.instance_type is None
451+
452+
def test_does_not_set_instance_count_when_instance_groups_configured(self):
    """instance_count must stay None when the Compute carries instance_groups."""
    group = InstanceGroup(
        instance_type="ml.p3.2xlarge",
        instance_count=1,
        instance_group_name="group1",
    )
    heterogeneous = Compute(
        instance_groups=[group],
        instance_type=None,
        instance_count=None,
        volume_size_in_gb=30,
    )

    resolved = TrainDefaults.get_compute(compute=heterogeneous)

    assert resolved.instance_count is None
462+
463+
def test_sets_volume_size_when_instance_groups_configured(self):
    """A missing volume_size_in_gb is still defaulted when instance_groups are set."""
    group = InstanceGroup(
        instance_type="ml.p3.2xlarge",
        instance_count=1,
        instance_group_name="group1",
    )
    heterogeneous = Compute(
        instance_groups=[group],
        instance_type=None,
        instance_count=None,
        volume_size_in_gb=None,
    )

    resolved = TrainDefaults.get_compute(compute=heterogeneous)

    assert resolved.volume_size_in_gb == DEFAULT_VOLUME_SIZE
473+
474+
def test_preserves_existing_volume_size_with_instance_groups(self):
    """An explicit volume_size_in_gb is kept as-is when instance_groups are set."""
    group = InstanceGroup(
        instance_type="ml.p3.2xlarge",
        instance_count=1,
        instance_group_name="group1",
    )
    heterogeneous = Compute(
        instance_groups=[group],
        instance_type=None,
        instance_count=None,
        volume_size_in_gb=100,
    )

    resolved = TrainDefaults.get_compute(compute=heterogeneous)

    assert resolved.volume_size_in_gb == 100
484+
485+
486+
class TestJumpStartTrainDefaultsGetComputeHeterogeneousCluster:
    """Exercise JumpStartTrainDefaults.get_compute when instance_groups are supplied."""

    @staticmethod
    def _resolve_compute(
        mock_get_session,
        mock_get_hub_content,
        *,
        document_volume_size,
        compute_volume_size,
    ):
        """Wire up the mocks shared by every test and return the get_compute result.

        Args:
            mock_get_session: Patched ``TrainDefaults.get_sagemaker_session``.
            mock_get_hub_content: Patched ``get_hub_content_and_document``.
            document_volume_size: Value assigned to the mocked document's
                ``TrainingVolumeSize``.
            compute_volume_size: ``volume_size_in_gb`` passed to the input Compute.
        """
        session = MagicMock()
        mock_get_session.return_value = session

        hub_document = MagicMock()
        hub_document.DefaultTrainingInstanceType = "ml.p3.2xlarge"
        hub_document.TrainingVolumeSize = document_volume_size
        mock_get_hub_content.return_value = (None, hub_document)

        jumpstart_config = MagicMock()
        jumpstart_config.training_config_name = None

        group = InstanceGroup(
            instance_type="ml.p3.2xlarge",
            instance_count=1,
            instance_group_name="group1",
        )
        heterogeneous = Compute(
            instance_groups=[group],
            instance_type=None,
            instance_count=None,
            volume_size_in_gb=compute_volume_size,
        )
        return JumpStartTrainDefaults.get_compute(
            jumpstart_config=jumpstart_config,
            compute=heterogeneous,
            sagemaker_session=session,
        )

    @patch("sagemaker.train.defaults.get_hub_content_and_document")
    @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session")
    def test_does_not_set_instance_type_when_instance_groups_configured(
        self, mock_get_session, mock_get_hub_content
    ):
        """instance_type must stay None when instance_groups are set."""
        resolved = self._resolve_compute(
            mock_get_session,
            mock_get_hub_content,
            document_volume_size=100,
            compute_volume_size=30,
        )
        assert resolved.instance_type is None

    @patch("sagemaker.train.defaults.get_hub_content_and_document")
    @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session")
    def test_does_not_set_instance_count_when_instance_groups_configured(
        self, mock_get_session, mock_get_hub_content
    ):
        """instance_count must stay None when instance_groups are set."""
        resolved = self._resolve_compute(
            mock_get_session,
            mock_get_hub_content,
            document_volume_size=100,
            compute_volume_size=30,
        )
        assert resolved.instance_count is None

    @patch("sagemaker.train.defaults.get_hub_content_and_document")
    @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session")
    def test_sets_volume_size_from_document_when_instance_groups_configured(
        self, mock_get_session, mock_get_hub_content
    ):
        """volume_size_in_gb comes from the document even with instance_groups set."""
        resolved = self._resolve_compute(
            mock_get_session,
            mock_get_hub_content,
            document_volume_size=100,
            compute_volume_size=None,
        )
        assert resolved.volume_size_in_gb == 100

    @patch("sagemaker.train.defaults.get_hub_content_and_document")
    @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session")
    def test_sets_default_volume_size_when_instance_groups_and_no_document_volume(
        self, mock_get_session, mock_get_hub_content
    ):
        """DEFAULT_VOLUME_SIZE is used when the document provides no volume size."""
        resolved = self._resolve_compute(
            mock_get_session,
            mock_get_hub_content,
            document_volume_size=None,
            compute_volume_size=None,
        )
        assert resolved.volume_size_in_gb == DEFAULT_VOLUME_SIZE

0 commit comments

Comments
 (0)