Skip to content

Commit 65d6c04

Browse files
authored
test: unskip two spark related tests (aws#5895)
* unskip two spark related tests
* fix: fallback to Spark 3.5 for py312 when default Spark version lacks py312 support
* fix: install local sagemaker-core wheel in Spark container for integ test. The Spark processing image does not have sagemaker-core pre-installed. Build the local dev wheel, upload to S3, and install it in the container via pre_execution_commands, mirroring the pattern used in sagemaker-mlops feature_processor integ tests.
* fix: replace assert with if/raise in remote spark test function. Pytest's assertion rewriting injects _pytest module references into the function bytecode. When cloudpickle serializes the function and the Spark container deserializes it, it fails with ModuleNotFoundError: No module named '_pytest' since pytest is not installed in the container.
1 parent 24e6ef0 commit 65d6c04

4 files changed

Lines changed: 70 additions & 11 deletions

File tree

sagemaker-core/src/sagemaker/core/remote_function/job.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,10 @@ def _get_default_spark_image(session):
860860
except ImportError:
861861
pass
862862

863+
# Spark 3.3 and below do not support py312; use 3.5 which supports both py39 and py312
864+
if py_version == "312" and spark_version in ("2.4", "3.0", "3.1", "3.2", "3.3"):
865+
spark_version = "3.5"
866+
863867
image_uri = image_uris.retrieve(
864868
framework=SPARK_NAME,
865869
region=region,

sagemaker-core/tests/integ/remote_function/conftest.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,52 @@ def spark_test_container(sagemaker_session, sagemaker_sdk_tar_path, tmp_path_fac
171171
)
172172

173173

174+
@pytest.fixture(scope="session")
def spark_pre_execution_commands(sagemaker_session):
    """Build the sagemaker-core wheel, upload it to S3, and return install commands.

    This mirrors the pattern used in sagemaker-mlops feature_processor integ tests.
    The Spark processing image does not have sagemaker-core pre-installed, so the
    local dev wheel is built here and installed in the container via
    ``pre_execution_commands``.

    Args:
        sagemaker_session: Session whose default bucket receives the wheel and
            which is passed to ``S3Uploader.upload``.

    Returns:
        list[str]: Shell commands that install awscli, copy the uploaded wheel
        from S3 to ``/tmp`` and pip-install it.

    Raises:
        subprocess.CalledProcessError: If the wheel build exits non-zero.
        FileNotFoundError: If the build produced no ``sagemaker_core-*.whl``.
    """
    import glob
    import subprocess
    import sys
    import tempfile
    from sagemaker.core.s3 import S3Uploader

    # This conftest lives four directories below the repository root.
    repo_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
    )
    core_dir = os.path.join(repo_root, "sagemaker-core")

    with tempfile.TemporaryDirectory() as dist_dir:
        # Pass an argv list (no shell) so the temp path is never re-parsed by a
        # shell, and use the interpreter running the tests rather than whatever
        # "python" happens to be first on PATH.
        subprocess.run(
            [sys.executable, "-m", "build", "--wheel", "--outdir", dist_dir],
            cwd=core_dir,
            check=True,
        )
        wheels = glob.glob(os.path.join(dist_dir, "sagemaker_core-*.whl"))
        if not wheels:
            raise FileNotFoundError(f"No sagemaker-core wheel found in {dist_dir}")
        wheel_path = wheels[0]
        wheel_name = os.path.basename(wheel_path)

        s3_prefix = "s3://{}/spark-integ-test/wheels".format(
            sagemaker_session.default_bucket()
        )
        # The upload must happen while the temporary directory still exists;
        # the wheel file is deleted as soon as the ``with`` block exits.
        S3Uploader.upload(wheel_path, s3_prefix, sagemaker_session=sagemaker_session)

    pip_install = "python3 -m pip install --root-user-action=ignore"
    aws_cli = "python3 -m awscli"
    return [
        f"{pip_install} awscli",
        f"{aws_cli} s3 cp {s3_prefix}/{wheel_name} /tmp/{wheel_name}",
        f"{pip_install} /tmp/{wheel_name}",
    ]
218+
219+
174220
@pytest.fixture(scope="session")
175221
def conda_env_yml():
176222
"""Write conda yml file needed for tests."""

sagemaker-core/tests/integ/remote_function/test_decorator.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -574,16 +574,18 @@ def my_func():
574574
assert client_error_message in str(error)
575575

576576

577-
@pytest.mark.skipif(
578-
sys.version_info[:2] not in [(3, 9), (3, 12)],
579-
reason="SageMaker Spark image only available for Python 3.9 and 3.12",
580-
)
581-
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
577+
# @pytest.mark.skipif(
578+
# sys.version_info[:2] not in [(3, 9), (3, 12)],
579+
# reason="SageMaker Spark image only available for Python 3.9 and 3.12",
580+
# )
581+
@pytest.mark.spark_py312
582+
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type, spark_pre_execution_commands):
582583
@remote(
583584
role=ROLE,
584585
instance_type=cpu_instance_type,
585586
sagemaker_session=sagemaker_session,
586587
keep_alive_period_in_seconds=60,
588+
pre_execution_commands=spark_pre_execution_commands,
587589
spark_config=SparkConfig(
588590
configuration=[
589591
{
@@ -598,7 +600,14 @@ def test_spark_transform():
598600

599601
spark = SparkSession.builder.getOrCreate()
600602

601-
assert spark.conf.get("spark.app.name") == "remote-spark-test"
603+
# Avoid bare assert here: pytest's assertion rewriting injects _pytest
604+
# module references into the function bytecode, which causes
605+
# deserialization to fail in the Spark container (no pytest installed).
606+
app_name = spark.conf.get("spark.app.name")
607+
if app_name != "remote-spark-test":
608+
raise RuntimeError(
609+
f"Expected spark.app.name='remote-spark-test', got '{app_name}'"
610+
)
602611

603612
test_spark_transform()
604613

sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -798,11 +798,11 @@ def transform(raw_s3_data_as_df):
798798
# sys.version_info[:2] not in [(3, 9), (3, 12)],
799799
# reason=f"SageMaker Spark image only supports Python 3.9 and 3.12, got {sys.version_info[:2]}",
800800
# )
801-
@pytest.mark.skip(
802-
reason="Lake Formation credential vending (GetTemporaryGlueTableCredentials) requires "
803-
"full LF environment setup (resource registration, trust policy, data location grants) "
804-
"that is not configured in CI. See quip-amazon.com/S3FEAMMMuKm0 for details."
805-
)
801+
# @pytest.mark.skip(
802+
# reason="Lake Formation credential vending (GetTemporaryGlueTableCredentials) requires "
803+
# "full LF environment setup (resource registration, trust policy, data location grants) "
804+
# "that is not configured in CI. See quip-amazon.com/S3FEAMMMuKm0 for details."
805+
# )
806806
@pytest.mark.spark_py312
807807
@pytest.mark.slow_test
808808
def test_to_pipeline_and_execute_with_lake_formation(

0 commit comments

Comments
 (0)