Skip to content

Commit b3301d1

Browse files
committed
fix: wire FrameworkProcessor code_location into code upload paths
`FrameworkProcessor.__init__` accepts `code_location` and the docstring states it controls where code is uploaded, but `_package_code` and `_create_and_upload_runproc` always used `default_bucket()`, ignoring it (Bug 2 reported in #5765) - Add `_s3_code_prefix()` helper that returns `code_location` when set, falling back to `default_bucket()`/`default_bucket_prefix` - Use `_s3_code_prefix()` in `_package_code` for `sourcedir.tar.gz` upload - Use `_s3_code_prefix()` in `_create_and_upload_runproc` pipeline path for `runproc.sh` upload - Non-pipeline `runproc.sh` and `_patch_inputs_with_payload` already derive their URIs from `_package_code`'s output, so they inherit the fix
1 parent 98683ac commit b3301d1

2 files changed

Lines changed: 88 additions & 6 deletions

File tree

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,16 @@ def __init__(
11191119
code_location[:-1] if (code_location and code_location.endswith("/")) else code_location
11201120
)
11211121

1122+
def _s3_code_prefix(self):
1123+
"""Return the S3 prefix for code uploads, respecting code_location if set."""
1124+
if self.code_location:
1125+
return self.code_location
1126+
return s3.s3_path_join(
1127+
"s3://",
1128+
self.sagemaker_session.default_bucket(),
1129+
self.sagemaker_session.default_bucket_prefix or "",
1130+
)
1131+
11221132
def _package_code(
11231133
self,
11241134
entry_point,
@@ -1155,9 +1165,7 @@ def _package_code(
11551165

11561166
# Upload to S3
11571167
s3_uri = s3.s3_path_join(
1158-
"s3://",
1159-
self.sagemaker_session.default_bucket(),
1160-
self.sagemaker_session.default_bucket_prefix or "",
1168+
self._s3_code_prefix(),
11611169
job_name,
11621170
"source",
11631171
"sourcedir.tar.gz",
@@ -1320,9 +1328,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
13201328
runproc_file_str = self._generate_framework_script(user_script)
13211329
runproc_file_hash = hash_object(runproc_file_str)
13221330
s3_uri = s3.s3_path_join(
1323-
"s3://",
1324-
self.sagemaker_session.default_bucket(),
1325-
self.sagemaker_session.default_bucket_prefix,
1331+
self._s3_code_prefix(),
13261332
_pipeline_config.pipeline_name,
13271333
"code",
13281334
runproc_file_hash,

sagemaker-core/tests/unit/test_processing.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,82 @@ def test_package_code_source_dir_not_exists(self, mock_session):
11541154
)
11551155

11561156

1157+
def test_package_code_with_code_location(self, mock_session):
1158+
processor = FrameworkProcessor(
1159+
role="arn:aws:iam::123456789012:role/SageMakerRole",
1160+
image_uri="test-image:latest",
1161+
instance_count=1,
1162+
instance_type="ml.m5.xlarge",
1163+
sagemaker_session=mock_session,
1164+
code_location="s3://my-custom-bucket/my-prefix",
1165+
)
1166+
1167+
with tempfile.TemporaryDirectory() as tmpdir:
1168+
entry_point = os.path.join(tmpdir, "train.py")
1169+
with open(entry_point, "w") as f:
1170+
f.write("print('training')")
1171+
1172+
result = processor._package_code(
1173+
entry_point=entry_point,
1174+
source_dir=tmpdir,
1175+
requirements=None,
1176+
job_name="test-job",
1177+
kms_key=None,
1178+
)
1179+
assert result.startswith("s3://my-custom-bucket/my-prefix")
1180+
assert "sourcedir.tar.gz" in result
1181+
1182+
def test_package_code_without_code_location_uses_default_bucket(self, mock_session):
1183+
processor = FrameworkProcessor(
1184+
role="arn:aws:iam::123456789012:role/SageMakerRole",
1185+
image_uri="test-image:latest",
1186+
instance_count=1,
1187+
instance_type="ml.m5.xlarge",
1188+
sagemaker_session=mock_session,
1189+
)
1190+
1191+
with tempfile.TemporaryDirectory() as tmpdir:
1192+
entry_point = os.path.join(tmpdir, "train.py")
1193+
with open(entry_point, "w") as f:
1194+
f.write("print('training')")
1195+
1196+
result = processor._package_code(
1197+
entry_point=entry_point,
1198+
source_dir=tmpdir,
1199+
requirements=None,
1200+
job_name="test-job",
1201+
kms_key=None,
1202+
)
1203+
assert result.startswith("s3://test-bucket/sagemaker")
1204+
assert "sourcedir.tar.gz" in result
1205+
1206+
def test_package_code_with_code_location_trailing_slash(self, mock_session):
1207+
processor = FrameworkProcessor(
1208+
role="arn:aws:iam::123456789012:role/SageMakerRole",
1209+
image_uri="test-image:latest",
1210+
instance_count=1,
1211+
instance_type="ml.m5.xlarge",
1212+
sagemaker_session=mock_session,
1213+
code_location="s3://my-custom-bucket/my-prefix/",
1214+
)
1215+
1216+
with tempfile.TemporaryDirectory() as tmpdir:
1217+
entry_point = os.path.join(tmpdir, "train.py")
1218+
with open(entry_point, "w") as f:
1219+
f.write("print('training')")
1220+
1221+
result = processor._package_code(
1222+
entry_point=entry_point,
1223+
source_dir=tmpdir,
1224+
requirements=None,
1225+
job_name="test-job",
1226+
kms_key=None,
1227+
)
1228+
# Trailing slash is stripped in __init__, so same result
1229+
assert result.startswith("s3://my-custom-bucket/my-prefix")
1230+
assert "sourcedir.tar.gz" in result
1231+
1232+
11571233
class TestFrameworkProcessorRun:
11581234
def test_run_with_s3_code(self, mock_session):
11591235
processor = FrameworkProcessor(

0 commit comments

Comments
 (0)