Skip to content

Commit e87f985

Browse files
author
Namrata Madan
committed
fix: rename signing key
1 parent c6082ea commit e87f985

10 files changed

Lines changed: 115 additions & 54 deletions

File tree

requirements/extras/test_requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ graphene
1313
typing_extensions>=4.9.0
1414
tensorflow>=2.16.2,<=2.19.0
1515
build
16-
docker>=5.0.2,<8.0.0
16+
docker>=5.0.2,<8.0
17+
filelock>=3.0.0

sagemaker-core/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"cloudpickle>=2.0.0",
3939
"paramiko>=2.11.0",
4040
"tblib>=1.7.0",
41+
"cryptography>=46.0.0",
4142
]
4243
requires-python = ">=3.9"
4344
classifiers = [

sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
self,
5656
sagemaker_session: Session,
5757
s3_base_uri: str,
58-
hmac_key: str,
58+
signing_key: str,
5959
s3_kms_key: str = None,
6060
context: Context = Context(),
6161
):
@@ -66,13 +66,13 @@ def __init__(
6666
AWS service calls are delegated to.
6767
s3_base_uri: the base uri to which serialized artifacts will be uploaded.
6868
s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
69-
hmac_key: Key used to encrypt serialized and deserialized function and arguments.
69+
signing_key: Key used to sign the serialized function and arguments on upload and to verify them on load.
7070
context: Build or run context of a pipeline step.
7171
"""
7272
self.sagemaker_session = sagemaker_session
7373
self.s3_base_uri = s3_base_uri
7474
self.s3_kms_key = s3_kms_key
75-
self.hmac_key = hmac_key
75+
self.signing_key = signing_key
7676
self.context = context
7777

7878
# For pipeline steps, function code is at: base/step_name/build_timestamp/
@@ -114,7 +114,7 @@ def save(self, func, *args, **kwargs):
114114
sagemaker_session=self.sagemaker_session,
115115
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
116116
s3_kms_key=self.s3_kms_key,
117-
private_key=self.hmac_key,
117+
private_key=self.signing_key,
118118
)
119119

120120
logger.info(
@@ -126,7 +126,7 @@ def save(self, func, *args, **kwargs):
126126
obj=(args, kwargs),
127127
sagemaker_session=self.sagemaker_session,
128128
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
129-
signing_key=self.hmac_key,
129+
signing_key=self.signing_key,
130130
s3_kms_key=self.s3_kms_key,
131131
)
132132

@@ -144,7 +144,7 @@ def save_pipeline_step_function(self, serialized_data):
144144
)
145145
serialization._upload_payload_and_metadata_to_s3_signed(
146146
bytes_to_upload=serialized_data.func,
147-
private_key=self.hmac_key,
147+
private_key=self.signing_key,
148148
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
149149
sagemaker_session=self.sagemaker_session,
150150
s3_kms_key=self.s3_kms_key,
@@ -156,7 +156,7 @@ def save_pipeline_step_function(self, serialized_data):
156156
)
157157
serialization._upload_payload_and_metadata_to_s3_signed(
158158
bytes_to_upload=serialized_data.args,
159-
private_key=self.hmac_key,
159+
private_key=self.signing_key,
160160
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
161161
sagemaker_session=self.sagemaker_session,
162162
s3_kms_key=self.s3_kms_key,
@@ -172,7 +172,7 @@ def load_and_invoke(self) -> Any:
172172
func = serialization.deserialize_func_from_s3(
173173
sagemaker_session=self.sagemaker_session,
174174
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
175-
public_key_pem=self.hmac_key,
175+
public_key_pem=self.signing_key,
176176
)
177177

178178
logger.info(
@@ -182,7 +182,7 @@ def load_and_invoke(self) -> Any:
182182
args, kwargs = serialization.deserialize_obj_from_s3(
183183
sagemaker_session=self.sagemaker_session,
184184
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
185-
verification_key=self.hmac_key,
185+
verification_key=self.signing_key,
186186
)
187187

188188
logger.info("Resolving pipeline variables")

sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _load_pipeline_context(args) -> Context:
9898

9999

100100
def _execute_remote_function(
101-
sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context
101+
sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, signing_key, context
102102
):
103103
"""Execute stored remote function"""
104104
from sagemaker.core.remote_function.core.stored_function import StoredFunction
@@ -107,7 +107,7 @@ def _execute_remote_function(
107107
sagemaker_session=sagemaker_session,
108108
s3_base_uri=s3_base_uri,
109109
s3_kms_key=s3_kms_key,
110-
hmac_key=hmac_key,
110+
signing_key=signing_key,
111111
context=context,
112112
)
113113

@@ -138,15 +138,15 @@ def main(sys_args=None):
138138
run_in_context = args.run_in_context
139139
pipeline_context = _load_pipeline_context(args)
140140

141-
hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY")
141+
signing_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY")
142142

143143
sagemaker_session = _get_sagemaker_session(region)
144144
_execute_remote_function(
145145
sagemaker_session=sagemaker_session,
146146
s3_base_uri=s3_base_uri,
147147
s3_kms_key=s3_kms_key,
148148
run_in_context=run_in_context,
149-
hmac_key=hmac_key,
149+
signing_key=signing_key,
150150
context=pipeline_context,
151151
)
152152

sagemaker-core/src/sagemaker/core/remote_function/job.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -842,19 +842,19 @@ def _get_default_spark_image(session):
842842
class _Job:
843843
"""Helper class that interacts with the SageMaker training service."""
844844

845-
def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str):
845+
def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, verification_key: str):
846846
"""Initialize a _Job object.
847847
848848
Args:
849849
job_name (str): The training job name.
850850
s3_uri (str): The training job output S3 uri.
851851
sagemaker_session (Session): SageMaker boto session.
852-
hmac_key (str): Remote function secret key.
852+
verification_key (str): Remote function secret key.
853853
"""
854854
self.job_name = job_name
855855
self.s3_uri = s3_uri
856856
self.sagemaker_session = sagemaker_session
857-
self.hmac_key = hmac_key
857+
self.verification_key = verification_key
858858
self._last_describe_response = None
859859

860860
@staticmethod
@@ -870,9 +870,9 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
870870
"""
871871
job_name = describe_training_job_response["TrainingJobName"]
872872
s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]
873-
hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"]
873+
verification_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"]
874874

875-
job = _Job(job_name, s3_uri, sagemaker_session, hmac_key)
875+
job = _Job(job_name, s3_uri, sagemaker_session, verification_key)
876876
job._last_describe_response = describe_training_job_response
877877
return job
878878

@@ -965,15 +965,15 @@ def compile(
965965
stored_function = StoredFunction(
966966
sagemaker_session=job_settings.sagemaker_session,
967967
s3_base_uri=s3_base_uri,
968-
hmac_key=private_key,
968+
signing_key=private_key,
969969
s3_kms_key=job_settings.s3_kms_key,
970970
)
971971
stored_function.save(func, *func_args, **func_kwargs)
972972
else:
973973
stored_function = StoredFunction(
974974
sagemaker_session=job_settings.sagemaker_session,
975975
s3_base_uri=s3_base_uri,
976-
hmac_key=private_key,
976+
signing_key=private_key,
977977
s3_kms_key=job_settings.s3_kms_key,
978978
context=Context(
979979
step_name=step_compilation_context.step_name,

sagemaker-core/tests/integ/remote_function/conftest.py

Lines changed: 88 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from contextlib import contextmanager
2222

2323
import docker
24+
import filelock
2425
import pytest
2526
from docker.errors import BuildError
2627

@@ -149,45 +150,52 @@ def gpu_instance_type():
149150

150151

151152
@pytest.fixture(scope="session")
152-
def dummy_container_without_error(sagemaker_session, compatible_python_version):
153-
ecr_uri = _build_container(sagemaker_session, compatible_python_version, DOCKERFILE_TEMPLATE)
154-
return ecr_uri
153+
def dummy_container_without_error(sagemaker_session, compatible_python_version, sagemaker_sdk_tar_path, tmp_path_factory):
154+
return _build_container_once(
155+
"dummy_container_without_error", sagemaker_session, compatible_python_version,
156+
DOCKERFILE_TEMPLATE, sagemaker_sdk_tar_path, tmp_path_factory,
157+
)
155158

156159

157160
@pytest.fixture(scope="session")
158-
def dummy_container_with_user_and_workdir(sagemaker_session, compatible_python_version):
159-
ecr_uri = _build_container(
160-
sagemaker_session,
161-
compatible_python_version,
162-
DOCKERFILE_TEMPLATE_WITH_USER_AND_WORKDIR,
161+
def dummy_container_with_user_and_workdir(sagemaker_session, compatible_python_version, sagemaker_sdk_tar_path, tmp_path_factory):
162+
return _build_container_once(
163+
"dummy_container_with_user_and_workdir", sagemaker_session, compatible_python_version,
164+
DOCKERFILE_TEMPLATE_WITH_USER_AND_WORKDIR, sagemaker_sdk_tar_path, tmp_path_factory,
163165
)
164-
return ecr_uri
165166

166167

167168
@pytest.fixture(scope="session")
168-
def dummy_container_incompatible_python_runtime(sagemaker_session, incompatible_python_version):
169-
ecr_uri = _build_container(sagemaker_session, incompatible_python_version, DOCKERFILE_TEMPLATE)
170-
return ecr_uri
169+
def dummy_container_incompatible_python_runtime(sagemaker_session, incompatible_python_version, sagemaker_sdk_tar_path, tmp_path_factory):
170+
return _build_container_once(
171+
"dummy_container_incompatible_python_runtime", sagemaker_session, incompatible_python_version,
172+
DOCKERFILE_TEMPLATE, sagemaker_sdk_tar_path, tmp_path_factory,
173+
)
171174

172175

173176
@pytest.fixture(scope="session")
174-
def dummy_container_with_conda(sagemaker_session, compatible_python_version):
175-
ecr_uri = _build_container(
176-
sagemaker_session, compatible_python_version, DOCKERFILE_TEMPLATE_WITH_CONDA
177+
def dummy_container_with_conda(sagemaker_session, compatible_python_version, sagemaker_sdk_tar_path, tmp_path_factory):
178+
return _build_container_once(
179+
"dummy_container_with_conda", sagemaker_session, compatible_python_version,
180+
DOCKERFILE_TEMPLATE_WITH_CONDA, sagemaker_sdk_tar_path, tmp_path_factory,
177181
)
178-
return ecr_uri
179182

180183

181184
@pytest.fixture(scope="session")
182-
def auto_capture_test_container(sagemaker_session):
183-
ecr_uri = _build_auto_capture_client_container("3.10", AUTO_CAPTURE_CLIENT_DOCKER_TEMPLATE)
184-
return ecr_uri
185+
def auto_capture_test_container(sagemaker_session, sagemaker_sdk_tar_path, tmp_path_factory):
186+
return _build_container_once(
187+
"auto_capture_test_container", sagemaker_session, "3.10",
188+
AUTO_CAPTURE_CLIENT_DOCKER_TEMPLATE, sagemaker_sdk_tar_path, tmp_path_factory,
189+
is_auto_capture=True,
190+
)
185191

186192

187193
@pytest.fixture(scope="session")
188-
def spark_test_container(sagemaker_session):
189-
ecr_uri = _build_container(sagemaker_session, "3.9", DOCKERFILE_TEMPLATE)
190-
return ecr_uri
194+
def spark_test_container(sagemaker_session, sagemaker_sdk_tar_path, tmp_path_factory):
195+
return _build_container_once(
196+
"spark_test_container", sagemaker_session, "3.9",
197+
DOCKERFILE_TEMPLATE, sagemaker_sdk_tar_path, tmp_path_factory,
198+
)
191199

192200

193201
@pytest.fixture(scope="session")
@@ -208,6 +216,27 @@ def conda_env_yml():
208216
os.remove(conda_yml_file_name)
209217

210218

219+
@pytest.fixture(scope="session")
220+
def sagemaker_sdk_tar_path(tmp_path_factory):
221+
"""Build the sagemaker-core sdist once and share it across all xdist workers.
222+
223+
Uses a file lock so only one worker runs the build; others wait and reuse
224+
the already-built tar.gz from the shared temp directory.
225+
"""
226+
# Under pytest-xdist, tmp_path_factory.getbasetemp().parent is shared by all workers (without xdist it is the per-user pytest temp root)
227+
root_tmp = tmp_path_factory.getbasetemp().parent
228+
tar_dir = root_tmp / "sagemaker_sdk_tar"
229+
tar_dir.mkdir(exist_ok=True)
230+
lock_file = root_tmp / "sagemaker_sdk_tar.lock"
231+
232+
with filelock.FileLock(str(lock_file)):
233+
existing = list(tar_dir.glob("*.tar.gz"))
234+
if not existing:
235+
_generate_sagemaker_sdk_tar(str(tar_dir))
236+
existing = list(tar_dir.glob("*.tar.gz"))
237+
return str(existing[0])
238+
239+
211240
def _tmpdir():
212241
"""Create a temporary directory context manager."""
213242
import tempfile
@@ -222,7 +251,33 @@ def _tmpdir():
222251
_tmpdir = contextmanager(_tmpdir)
223252

224253

225-
def _build_container(sagemaker_session, py_version, docker_template):
254+
def _build_container_once(
255+
fixture_name, sagemaker_session, py_version, docker_template, sdk_tar_path,
256+
tmp_path_factory, is_auto_capture=False,
257+
):
258+
"""Build and push a container image exactly once across all xdist workers.
259+
260+
Uses a file lock keyed by fixture_name so parallel workers wait for the
261+
first worker to finish, then reuse the ECR URI written to a shared file.
262+
"""
263+
root_tmp = tmp_path_factory.getbasetemp().parent
264+
uri_file = root_tmp / f"{fixture_name}.ecr_uri"
265+
lock_file = root_tmp / f"{fixture_name}.lock"
266+
267+
with filelock.FileLock(str(lock_file)):
268+
if uri_file.exists():
269+
return uri_file.read_text().strip()
270+
if is_auto_capture:
271+
ecr_uri = _build_auto_capture_client_container(
272+
py_version, docker_template, sdk_tar_path
273+
)
274+
else:
275+
ecr_uri = _build_container(sagemaker_session, py_version, docker_template, sdk_tar_path)
276+
uri_file.write_text(ecr_uri)
277+
return ecr_uri
278+
279+
280+
def _build_container(sagemaker_session, py_version, docker_template, sdk_tar_path):
226281
"""Build a dummy test container locally and push to ECR."""
227282
region = sagemaker_session.boto_region_name
228283
image_tag = f"{py_version.replace('.', '-')}-{sagemaker_timestamp()}"
@@ -231,7 +286,8 @@ def _build_container(sagemaker_session, py_version, docker_template):
231286

232287
with _tmpdir() as tmpdir:
233288
print("building docker image locally in ", tmpdir)
234-
source_archive = _generate_sagemaker_sdk_tar(tmpdir)
289+
source_archive = os.path.basename(sdk_tar_path)
290+
shutil.copy2(sdk_tar_path, os.path.join(tmpdir, source_archive))
235291
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
236292
content = docker_template.format(py_version=py_version, source_archive=source_archive)
237293
print(f"Dockerfile contents: \n{content}\n")
@@ -267,10 +323,11 @@ def _build_container(sagemaker_session, py_version, docker_template):
267323
return ecr_image
268324

269325

270-
def _build_auto_capture_client_container(py_version, docker_template):
326+
def _build_auto_capture_client_container(py_version, docker_template, sdk_tar_path):
271327
"""Build a test docker container for auto_capture tests."""
272328
with _tmpdir() as tmpdir:
273-
source_archive = _generate_sdk_tar_with_public_version(tmpdir)
329+
source_archive = os.path.basename(sdk_tar_path)
330+
shutil.copy2(sdk_tar_path, os.path.join(tmpdir, source_archive))
274331
_move_auto_capture_test_file(tmpdir)
275332
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
276333
content = docker_template.format(py_version=py_version, source_archive=source_archive)
@@ -304,7 +361,11 @@ def _ecr_image_uri(account, region, image_name, tag):
304361
def _generate_sagemaker_sdk_tar(destination_folder):
305362
"""Run build to generate the SDK tar file."""
306363
command = f"python -m build --sdist -o {destination_folder}"
307-
result = subprocess.run(command, shell=True, check=True, capture_output=True)
364+
try:
365+
subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
366+
except subprocess.CalledProcessError as e:
367+
print(f"Error when building sagemaker-core sdist: {e.stderr}")
368+
raise
308369
destination_folder_contents = os.listdir(destination_folder)
309370
source_archive = [f for f in destination_folder_contents if f.endswith("tar.gz")][0]
310371
return source_archive

sagemaker-core/tests/unit/remote_function/test_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_init(self, mock_session):
146146
job = _Job("test-job", "s3://bucket/output", mock_session, "test-key")
147147
assert job.job_name == "test-job"
148148
assert job.s3_uri == "s3://bucket/output"
149-
assert job.hmac_key == "test-key"
149+
assert job.verification_key == "test-key"
150150

151151
def test_from_describe_response(self, mock_session):
152152
"""Test creating _Job from describe response."""
@@ -158,7 +158,7 @@ def test_from_describe_response(self, mock_session):
158158
job = _Job.from_describe_response(response, mock_session)
159159
assert job.job_name == "test-job"
160160
assert job.s3_uri == "s3://bucket/output"
161-
assert job.hmac_key == "test-key"
161+
assert job.verification_key == "test-key"
162162

163163
def test_describe_returns_cached_response(self, mock_session):
164164
"""Test that describe returns cached response for completed jobs."""

sagemaker-core/tests/unit/remote_function/test_job_comprehensive.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ def test_from_describe_response(self, mock_session):
149149
job = _Job.from_describe_response(response, mock_session)
150150
assert job.job_name == "test-job"
151151
assert job.s3_uri == "s3://bucket/output"
152-
assert job.hmac_key == "test-key"
153-
assert job._last_describe_response == response
152+
assert job.verification_key == "test-key"
154153

155154
def test_describe_cached_completed(self, mock_session):
156155
"""Test lines 865-871: describe with cached completed job."""

0 commit comments

Comments
 (0)