Skip to content

Commit 1ea0cbb

Browse files
author
Raul Diaz Garcia
committed
feature: allow custom runproc.sh in FrameworkProcessor
1 parent 91ca011 commit 1ea0cbb

2 files changed

Lines changed: 118 additions & 5 deletions

File tree

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,7 @@ def run(
11891189
job_name: Optional[str] = None,
11901190
experiment_config: Optional[Dict[str, str]] = None,
11911191
kms_key: Optional[str] = None,
1192+
entry_point: Optional[str] = None
11921193
):
11931194
"""Runs a processing job.
11941195
@@ -1216,6 +1217,9 @@ def run(
12161217
experiment_config (dict[str, str]): Experiment management configuration.
12171218
kms_key (str): The ARN of the KMS key that is used to encrypt the
12181219
user code file (default: None).
1220+
entry_point (str): Path (absolute or relative) to a custom entrypoint script file
1221+
(e.g., runproc.sh). The python script call is appended automatically.
1222+
12191223
Returns:
12201224
None or pipeline step arguments in case the Processor instance is built with
12211225
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
@@ -1227,6 +1231,7 @@ def run(
12271231
job_name,
12281232
inputs,
12291233
kms_key,
1234+
entry_point
12301235
)
12311236

12321237
# Submit a processing job.
@@ -1250,6 +1255,7 @@ def _pack_and_upload_code(
12501255
job_name,
12511256
inputs,
12521257
kms_key=None,
1258+
entry_point=None
12531259
):
12541260
"""Pack local code bundle and upload to Amazon S3."""
12551261
if code.startswith("s3://"):
@@ -1274,7 +1280,7 @@ def _pack_and_upload_code(
12741280
script = os.path.basename(code)
12751281
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
12761282
s3_runproc_sh = self._create_and_upload_runproc(
1277-
script, evaluated_kms_key, entrypoint_s3_uri
1283+
script, evaluated_kms_key, entrypoint_s3_uri, entry_point, source_dir
12781284
)
12791285

12801286
return s3_runproc_sh, inputs, job_name
@@ -1312,12 +1318,12 @@ def _set_entrypoint(self, command, user_script_name):
13121318
)
13131319
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
13141320

1315-
def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
1321+
def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri, entry_point=None, source_dir=None):
13161322
"""Create runproc shell script and upload to S3 bucket."""
13171323
from sagemaker.core.workflow.utilities import _pipeline_config, hash_object
13181324

13191325
if _pipeline_config and _pipeline_config.pipeline_name:
1320-
runproc_file_str = self._generate_framework_script(user_script)
1326+
runproc_file_str = self._generate_framework_script(user_script, entry_point, source_dir)
13211327
runproc_file_hash = hash_object(runproc_file_str)
13221328
s3_uri = s3.s3_path_join(
13231329
"s3://",
@@ -1336,16 +1342,19 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
13361342
)
13371343
else:
13381344
s3_runproc_sh = s3.S3Uploader.upload_string_as_file_body(
1339-
self._generate_framework_script(user_script),
1345+
self._generate_framework_script(user_script, entry_point, source_dir),
13401346
desired_s3_uri=entrypoint_s3_uri,
13411347
kms_key=kms_key,
13421348
sagemaker_session=self.sagemaker_session,
13431349
)
13441350

13451351
return s3_runproc_sh
13461352

1347-
def _generate_framework_script(self, user_script: str) -> str:
1353+
def _generate_framework_script(self, user_script: str, entry_point: str = None, source_dir: str = None) -> str:
13481354
"""Generate the framework entrypoint file (as text) for a processing job."""
1355+
if entry_point:
1356+
return self._generate_custom_framework_script(user_script, entry_point, source_dir)
1357+
13491358
return dedent(
13501359
"""\
13511360
#!/bin/bash
@@ -1383,6 +1392,45 @@ def _generate_framework_script(self, user_script: str) -> str:
13831392
entry_point=user_script,
13841393
)
13851394

1395+
def _generate_custom_framework_script(
    self, user_script: str, entry_point: str, source_dir: Optional[str] = None
) -> str:
    """Generate a framework script that embeds a user-provided entrypoint file.

    The text of the ``entry_point`` file is placed at the top of the generated
    script, followed by a blank line and a single command line that runs
    ``user_script`` with ``self.command`` and forwards all script arguments
    (``"$@"``). Nothing else is added to the script, so any preparation the
    job needs before the user script runs has to come from the custom
    entrypoint itself.

    Args:
        user_script (str): Relative path to the user script in the source bundle.
            It is shell-quoted before being written into the generated script.
        entry_point (str): Path to the custom entrypoint script file.
        source_dir (str): Path to the source directory. If provided and
            ``entry_point`` is relative, ``entry_point`` is resolved against it;
            otherwise ``entry_point`` is opened as given.

    Returns:
        str: The generated script content.

    Raises:
        FileNotFoundError: If the resolved entrypoint file does not exist. The
            error carries the resolved path so a wrong ``entry_point`` /
            ``source_dir`` combination is easy to spot.
    """
    # Local import: shlex is only needed on this code path.
    import shlex

    # Resolve the full path to the entry_point file.
    if source_dir and not os.path.isabs(entry_point):
        entry_point_path = os.path.join(source_dir, entry_point)
    else:
        entry_point_path = entry_point

    # Read the entry_point file content, keeping errno/filename on failure.
    try:
        with open(entry_point_path, "r", encoding="utf-8") as script_file:
            entry_point_content = script_file.read()
    except FileNotFoundError as err:
        raise FileNotFoundError(
            err.errno,
            "Custom entry_point script for the processing job was not found",
            entry_point_path,
        ) from err

    # Quote the script name so names with spaces or shell metacharacters do
    # not word-split; plain names such as "train.py" are emitted unchanged.
    run_command = " ".join(self.command)
    quoted_script = shlex.quote(user_script)
    return f'{entry_point_content}\n\n{run_command} {quoted_script} "$@"\n'
1433+
13861434

13871435
class FeatureStoreOutput(ApiObject):
13881436
"""Configuration for processing job outputs in Amazon SageMaker Feature Store."""

sagemaker-core/tests/unit/test_processing.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,71 @@ def test_create_and_upload_runproc_without_pipeline(self, mock_session):
863863
)
864864
assert result == "s3://bucket/runproc.sh"
865865

866+
def test_generate_framework_script_with_custom_entry_point(self, mock_session):
    """A custom entry_point file replaces the default framework script body."""
    processor_kwargs = {
        "role": "arn:aws:iam::123456789012:role/SageMakerRole",
        "image_uri": "test-image:latest",
        "command": ["python3"],
        "instance_count": 1,
        "instance_type": "ml.m5.xlarge",
        "sagemaker_session": mock_session,
    }
    framework_processor = FrameworkProcessor(**processor_kwargs)

    expected_content = "#!/bin/bash\necho 'THIS IS THE CUSTOM runproc.sh'\nset -e\n"

    # delete=False so the file outlives the context manager and can be
    # re-opened by the processor; it is removed in the finally clause.
    tmp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False)
    with tmp_file:
        tmp_file.write(expected_content)
        custom_entry_point = tmp_file.name

    try:
        generated = framework_processor._generate_framework_script(
            "train.py", entry_point=custom_entry_point
        )
        # The custom content and the appended python call are present ...
        assert expected_content in generated
        assert "python3 train.py" in generated
        # ... while the default script's extraction step is not.
        assert "tar -xzf sourcedir.tar.gz" not in generated
    finally:
        os.remove(custom_entry_point)
891+
892+
def test_generate_framework_script_with_custom_entry_point_and_source_dir(self, mock_session):
    """A relative entry_point is looked up inside source_dir."""
    processor_kwargs = {
        "role": "arn:aws:iam::123456789012:role/SageMakerRole",
        "image_uri": "test-image:latest",
        "command": ["python3"],
        "instance_count": 1,
        "instance_type": "ml.m5.xlarge",
        "sagemaker_session": mock_session,
    }
    framework_processor = FrameworkProcessor(**processor_kwargs)

    with tempfile.TemporaryDirectory() as source_root:
        expected_content = "#!/bin/bash\necho 'custom from source_dir'\n"
        entry_point_file = os.path.join(source_root, "custom_runproc.sh")
        with open(entry_point_file, "w") as handle:
            handle.write(expected_content)

        # Only the bare file name is passed; source_dir supplies the directory.
        generated = framework_processor._generate_framework_script(
            "train.py",
            entry_point="custom_runproc.sh",
            source_dir=source_root,
        )
        for fragment in (expected_content, "python3 train.py"):
            assert fragment in generated
915+
916+
def test_generate_framework_script_with_default_entry_point(self, mock_session):
    """Without an entry_point the default framework script is generated."""
    processor_kwargs = {
        "role": "arn:aws:iam::123456789012:role/SageMakerRole",
        "image_uri": "test-image:latest",
        "command": ["python3"],
        "instance_count": 1,
        "instance_type": "ml.m5.xlarge",
        "sagemaker_session": mock_session,
    }
    framework_processor = FrameworkProcessor(**processor_kwargs)

    generated = framework_processor._generate_framework_script("train.py")

    # Shebang, source extraction and the python call must all be present.
    expected_fragments = (
        "#!/bin/bash",
        "tar -xzf sourcedir.tar.gz",
        "python3 train.py",
    )
    for fragment in expected_fragments:
        assert fragment in generated
930+
866931

867932
class TestHelperFunctions:
868933
def test_processing_input_to_request_dict(self):

0 commit comments

Comments
 (0)