Skip to content

Commit 55a4ee5

Browse files
authored
V3 Bug Fixes (#5601)
* V3 Bug Fixes * fix(model_builder): Only set s3_upload_path for S3 URIs in passthrough In _build_for_passthrough(), model_path could be a local /tmp path. Setting s3_upload_path to a local path caused CreateModel API to reject the modelDataUrl with a validation error since it requires s3:// or https:// URIs. Now only S3 URIs are assigned to s3_upload_path; local paths are handled separately by _prepare_for_mode() in LOCAL_CONTAINER mode. * Test fixes * Bug fix 3 and 4
1 parent 3aaf364 commit 55a4ee5

File tree

11 files changed

+1466
-21
lines changed

11 files changed

+1466
-21
lines changed

sagemaker-core/src/sagemaker/core/modules/local_core/local_container.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import re
1919
import shutil
20+
import stat
2021
import subprocess
2122
from tempfile import TemporaryDirectory
2223
from typing import Any, Dict, List, Optional
@@ -57,6 +58,17 @@
5758
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"
5859

5960

61+
def _rmtree(path):
62+
"""Remove a directory tree, handling root-owned files from Docker containers."""
63+
def _onerror(func, path, exc_info):
64+
if isinstance(exc_info[1], PermissionError):
65+
os.chmod(path, stat.S_IRWXU)
66+
func(path)
67+
else:
68+
raise exc_info[1]
69+
shutil.rmtree(path, onerror=_onerror)
70+
71+
6072
class _LocalContainer(BaseModel):
6173
"""A local training job class for local mode model trainer.
6274
@@ -209,12 +221,12 @@ def train(
209221
# Print our Job Complete line
210222
logger.info("Local training job completed, output artifacts saved to %s", artifacts)
211223

212-
shutil.rmtree(os.path.join(self.container_root, "input"))
213-
shutil.rmtree(os.path.join(self.container_root, "shared"))
224+
_rmtree(os.path.join(self.container_root, "input"))
225+
_rmtree(os.path.join(self.container_root, "shared"))
214226
for host in self.hosts:
215-
shutil.rmtree(os.path.join(self.container_root, host))
227+
_rmtree(os.path.join(self.container_root, host))
216228
for folder in self._temporary_folders:
217-
shutil.rmtree(os.path.join(self.container_root, folder))
229+
_rmtree(os.path.join(self.container_root, folder))
218230
return artifacts
219231

220232
def retrieve_artifacts(

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,9 @@ def _normalize_outputs(self, outputs=None):
487487
# If the output's s3_uri is not an s3_uri, create one.
488488
parse_result = urlparse(output.s3_output.s3_uri)
489489
if parse_result.scheme != "s3":
490+
if getattr(self.sagemaker_session, "local_mode", False) and parse_result.scheme == "file":
491+
normalized_outputs.append(output)
492+
continue
490493
if _pipeline_config:
491494
s3_uri = Join(
492495
on="/",

sagemaker-core/tests/unit/test_processing.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,261 @@ def test_normalize_outputs_invalid_type(self, mock_session):
238238
processor._normalize_outputs(["invalid"])
239239

240240

241+
242+
243+
class TestBugConditionFileUriReplacedInLocalMode:
244+
"""Bug condition exploration test: file:// URIs should be preserved in local mode.
245+
246+
**Validates: Requirements 1.1, 1.2, 2.1, 2.2**
247+
248+
EXPECTED TO FAIL on unfixed code — failure confirms the bug exists.
249+
The bug is that _normalize_outputs() replaces file:// URIs with s3:// paths
250+
even when the session is a LocalSession (local_mode=True).
251+
"""
252+
253+
@pytest.fixture
254+
def local_mock_session(self):
255+
session = Mock()
256+
session.boto_session = Mock()
257+
session.boto_session.region_name = "us-west-2"
258+
session.sagemaker_client = Mock()
259+
session.default_bucket = Mock(return_value="default-bucket")
260+
session.default_bucket_prefix = "prefix"
261+
session.expand_role = Mock(side_effect=lambda x: x)
262+
session.sagemaker_config = {}
263+
session.local_mode = True
264+
return session
265+
266+
@pytest.mark.parametrize(
267+
"file_uri",
268+
[
269+
"file:///tmp/output",
270+
"file:///home/user/results",
271+
"file:///data/processed",
272+
],
273+
)
274+
def test_normalize_outputs_preserves_file_uri_in_local_mode(self, local_mock_session, file_uri):
275+
"""file:// URIs must be preserved when local_mode=True.
276+
277+
On unfixed code, _normalize_outputs replaces file:// URIs with
278+
s3://default-bucket/prefix/job-name/output/output-1, which is the bug.
279+
"""
280+
processor = Processor(
281+
role="arn:aws:iam::123456789012:role/SageMakerRole",
282+
image_uri="test-image:latest",
283+
instance_count=1,
284+
instance_type="ml.m5.xlarge",
285+
sagemaker_session=local_mock_session,
286+
)
287+
processor._current_job_name = "test-job"
288+
289+
s3_output = ProcessingS3Output(
290+
s3_uri=file_uri,
291+
local_path="/opt/ml/processing/output",
292+
s3_upload_mode="EndOfJob",
293+
)
294+
outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)]
295+
296+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
297+
result = processor._normalize_outputs(outputs)
298+
299+
assert len(result) == 1
300+
assert result[0].s3_output.s3_uri == file_uri, (
301+
f"Expected file:// URI to be preserved as '{file_uri}' in local mode, "
302+
f"but got '{result[0].s3_output.s3_uri}'"
303+
)
304+
305+
306+
class TestPreservationNonLocalFileBehavior:
307+
"""Preservation property tests: Non-local-file behavior must remain unchanged.
308+
309+
**Validates: Requirements 3.1, 3.2, 3.3, 3.4**
310+
311+
These tests capture baseline behavior on UNFIXED code. They MUST PASS on both
312+
unfixed and fixed code, confirming no regressions are introduced by the fix.
313+
"""
314+
315+
@pytest.fixture
316+
def session_local_mode_true(self):
317+
session = Mock()
318+
session.boto_session = Mock()
319+
session.boto_session.region_name = "us-west-2"
320+
session.sagemaker_client = Mock()
321+
session.default_bucket = Mock(return_value="default-bucket")
322+
session.default_bucket_prefix = "prefix"
323+
session.expand_role = Mock(side_effect=lambda x: x)
324+
session.sagemaker_config = {}
325+
session.local_mode = True
326+
return session
327+
328+
@pytest.fixture
329+
def session_local_mode_false(self):
330+
session = Mock()
331+
session.boto_session = Mock()
332+
session.boto_session.region_name = "us-west-2"
333+
session.sagemaker_client = Mock()
334+
session.default_bucket = Mock(return_value="default-bucket")
335+
session.default_bucket_prefix = "prefix"
336+
session.expand_role = Mock(side_effect=lambda x: x)
337+
session.sagemaker_config = {}
338+
session.local_mode = False
339+
return session
340+
341+
def _make_processor(self, session):
342+
processor = Processor(
343+
role="arn:aws:iam::123456789012:role/SageMakerRole",
344+
image_uri="test-image:latest",
345+
instance_count=1,
346+
instance_type="ml.m5.xlarge",
347+
sagemaker_session=session,
348+
)
349+
processor._current_job_name = "test-job"
350+
return processor
351+
352+
# --- Requirement 3.1: S3 URIs pass through unchanged regardless of local_mode ---
353+
354+
@pytest.mark.parametrize(
355+
"s3_uri,local_mode_fixture",
356+
[
357+
("s3://my-bucket/path", "session_local_mode_true"),
358+
("s3://my-bucket/path", "session_local_mode_false"),
359+
("s3://another-bucket/deep/nested/path", "session_local_mode_true"),
360+
("s3://another-bucket/deep/nested/path", "session_local_mode_false"),
361+
],
362+
)
363+
def test_s3_uri_preserved_regardless_of_local_mode(self, s3_uri, local_mode_fixture, request):
364+
"""S3 URIs must pass through unchanged regardless of local_mode setting.
365+
366+
**Validates: Requirements 3.1**
367+
"""
368+
session = request.getfixturevalue(local_mode_fixture)
369+
processor = self._make_processor(session)
370+
371+
s3_output = ProcessingS3Output(
372+
s3_uri=s3_uri,
373+
local_path="/opt/ml/processing/output",
374+
s3_upload_mode="EndOfJob",
375+
)
376+
outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)]
377+
378+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
379+
result = processor._normalize_outputs(outputs)
380+
381+
assert len(result) == 1
382+
assert result[0].s3_output.s3_uri == s3_uri
383+
384+
# --- Requirement 3.2: Non-S3 URIs with local_mode=False replaced with S3 paths ---
385+
386+
@pytest.mark.parametrize(
387+
"non_s3_uri",
388+
[
389+
"/local/output/path",
390+
"http://example.com/output",
391+
"ftp://server/output",
392+
],
393+
)
394+
def test_non_s3_uri_replaced_when_not_local_mode(self, non_s3_uri, session_local_mode_false):
395+
"""Non-S3 URIs in non-local sessions are replaced with auto-generated S3 paths.
396+
397+
**Validates: Requirements 3.2**
398+
"""
399+
processor = self._make_processor(session_local_mode_false)
400+
401+
s3_output = ProcessingS3Output(
402+
s3_uri=non_s3_uri,
403+
local_path="/opt/ml/processing/output",
404+
s3_upload_mode="EndOfJob",
405+
)
406+
outputs = [ProcessingOutput(output_name="output-1", s3_output=s3_output)]
407+
408+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
409+
result = processor._normalize_outputs(outputs)
410+
411+
assert len(result) == 1
412+
assert result[0].s3_output.s3_uri.startswith("s3://default-bucket/")
413+
414+
# --- Requirement 3.3: Pipeline variable URIs skip normalization ---
415+
416+
def test_pipeline_variable_uri_skips_normalization(self, session_local_mode_false):
417+
"""Pipeline variable URIs skip normalization entirely.
418+
419+
**Validates: Requirements 3.3**
420+
"""
421+
processor = self._make_processor(session_local_mode_false)
422+
423+
s3_output = ProcessingS3Output(
424+
s3_uri="s3://bucket/output",
425+
local_path="/opt/ml/processing/output",
426+
s3_upload_mode="EndOfJob",
427+
)
428+
outputs = [ProcessingOutput(output_name="output-1", s3_output=s3_output)]
429+
430+
with patch("sagemaker.core.processing.is_pipeline_variable", return_value=True):
431+
result = processor._normalize_outputs(outputs)
432+
433+
assert len(result) == 1
434+
# Pipeline variable outputs are appended as-is without URI modification
435+
assert result[0].s3_output.s3_uri == "s3://bucket/output"
436+
437+
# --- Requirement 3.4: Non-ProcessingOutput objects raise TypeError ---
438+
439+
@pytest.mark.parametrize(
440+
"invalid_output",
441+
[
442+
["a string"],
443+
[42],
444+
[{"key": "value"}],
445+
],
446+
)
447+
def test_non_processing_output_raises_type_error(self, invalid_output, session_local_mode_false):
448+
"""Non-ProcessingOutput objects must raise TypeError.
449+
450+
**Validates: Requirements 3.4**
451+
"""
452+
processor = self._make_processor(session_local_mode_false)
453+
454+
with pytest.raises(TypeError, match="must be provided as ProcessingOutput objects"):
455+
processor._normalize_outputs(invalid_output)
456+
457+
# --- Output name auto-generation ---
458+
459+
def test_multiple_outputs_with_s3_uris_preserved(self, session_local_mode_false):
460+
"""Multiple outputs with S3 URIs are all preserved unchanged.
461+
462+
**Validates: Requirements 3.1, 3.2**
463+
"""
464+
processor = self._make_processor(session_local_mode_false)
465+
466+
outputs = [
467+
ProcessingOutput(
468+
output_name="first-output",
469+
s3_output=ProcessingS3Output(
470+
s3_uri="s3://my-bucket/first",
471+
local_path="/opt/ml/processing/output1",
472+
s3_upload_mode="EndOfJob",
473+
),
474+
),
475+
ProcessingOutput(
476+
output_name="second-output",
477+
s3_output=ProcessingS3Output(
478+
s3_uri="s3://my-bucket/second",
479+
local_path="/opt/ml/processing/output2",
480+
s3_upload_mode="EndOfJob",
481+
),
482+
),
483+
]
484+
485+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
486+
result = processor._normalize_outputs(outputs)
487+
488+
assert len(result) == 2
489+
assert result[0].output_name == "first-output"
490+
assert result[1].output_name == "second-output"
491+
# S3 URIs should be preserved since they already have s3:// scheme
492+
assert result[0].s3_output.s3_uri == "s3://my-bucket/first"
493+
assert result[1].s3_output.s3_uri == "s3://my-bucket/second"
494+
495+
241496
class TestProcessorStartNew:
242497
def test_start_new_with_pipeline_session(self, mock_session):
243498
from sagemaker.core.workflow.pipeline_context import PipelineSession

sagemaker-serve/src/sagemaker/serve/mode/local_container_mode.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
from __future__ import absolute_import
44
from pathlib import Path
55
import logging
6+
import os
67
from datetime import datetime, timedelta
78
from typing import Dict, Type
89
import base64
910
import time
1011
import subprocess
1112
import docker
1213

14+
from sagemaker.core.local.utils import check_for_studio
15+
1316
from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing
1417
from sagemaker.serve.spec.inference_spec import InferenceSpec
1518
from sagemaker.serve.builder.schema_builder import SchemaBuilder
@@ -33,6 +36,25 @@
3336
+ "Please increase container_timeout_seconds or review your inference code."
3437
)
3538

39+
STUDIO_DOCKER_SOCKET_PATHS = [
40+
"/docker/proxy/docker.sock",
41+
"/var/run/docker.sock",
42+
]
43+
44+
45+
def _get_docker_client():
46+
"""Get a Docker client, handling SageMaker Studio's non-standard socket path."""
47+
if os.environ.get("DOCKER_HOST"):
48+
return docker.from_env()
49+
try:
50+
if check_for_studio():
51+
for socket_path in STUDIO_DOCKER_SOCKET_PATHS:
52+
if os.path.exists(socket_path):
53+
return docker.DockerClient(base_url=f"unix://{socket_path}")
54+
except (NotImplementedError, Exception):
55+
pass
56+
return docker.from_env()
57+
3658

3759
class LocalContainerMode(
3860
LocalTorchServe,
@@ -212,7 +234,7 @@ def _pull_image(self, image: str):
212234

213235
# Check if Docker is available first
214236
try:
215-
self.client = docker.from_env()
237+
self.client = _get_docker_client()
216238
self.client.ping() # Test Docker connection
217239
except Exception as e:
218240
raise RuntimeError(

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@
133133
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
134134
MODEL_CONTAINERS_PATH,
135135
)
136-
from sagemaker.serve.constants import SUPPORTED_MODEL_SERVERS, Framework
136+
from sagemaker.serve.constants import LOCAL_MODES, SUPPORTED_MODEL_SERVERS, Framework
137137
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
138138
from sagemaker.core import fw_utils
139139
from sagemaker.core.helper.session_helper import container_def
@@ -1287,7 +1287,16 @@ def _build_for_passthrough(self) -> Model:
12871287
if not self.image_uri:
12881288
raise ValueError("image_uri is required for pass-through cases")
12891289

1290-
self.s3_upload_path = None
1290+
self.secret_key = ""
1291+
1292+
if self.model_path and self.model_path.startswith("s3://"):
1293+
self.s3_upload_path = self.model_path
1294+
else:
1295+
self.s3_upload_path = None
1296+
1297+
if self.mode in LOCAL_MODES:
1298+
self._prepare_for_mode()
1299+
12911300
return self._create_model()
12921301

12931302
def _build_default_async_inference_config(self, async_inference_config):

0 commit comments

Comments
 (0)