Skip to content

Commit b8a776b

Browse files
committed
fix: Fix flaky integ tests
1 parent f2b8283 commit b8a776b

File tree

2 files changed

+94
-57
lines changed

2 files changed

+94
-57
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# unreleased sagemaker is installed via pre_execution_commands

sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py

Lines changed: 93 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import glob
1516
import logging
1617
import os
1718
import subprocess
1819
import sys
20+
import tempfile
1921
import time
2022
from typing import Dict
2123
from datetime import datetime
@@ -121,6 +123,16 @@ def sagemaker_session():
121123
return Session()
122124

123125

@pytest.fixture(scope="module")
def pre_execution_commands(sagemaker_session):
    """Module-scoped fixture: shell commands that install the locally built SDK wheels.

    Built once per test module so the (expensive) wheel build and S3 upload in
    get_pre_execution_commands run a single time for all remote tests.
    """
    commands = get_pre_execution_commands(sagemaker_session=sagemaker_session)
    return commands
@pytest.fixture(scope="module")
def dependencies_path():
    """Module-scoped fixture: path to the requirements.txt test asset."""
    requirements_file = os.path.join(_FEATURE_PROCESSOR_DIR, "requirements.txt")
    return requirements_file
124136
@pytest.mark.slow_test
125137
def test_feature_processor_transform_online_only_store_ingestion(
126138
sagemaker_session,
@@ -137,8 +149,6 @@ def test_feature_processor_transform_online_only_store_ingestion(
137149

138150
raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session)
139151

140-
print("About to apply @feature_processor decorator...")
141-
142152
@feature_processor(
143153
inputs=[CSVDataSource(raw_data_uri)],
144154
output=feature_groups["car_data_arn"],
@@ -175,9 +185,8 @@ def transform(raw_s3_data_as_df):
175185
transformed_df.show()
176186
return transformed_df
177187

178-
print("Decorator applied. About to call transform()...")
179-
transform()
180-
print("transform() completed.")
188+
# this calls spark 3.3 which requires java 11
189+
transform()
181190

182191
featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client
183192
results = featurestore_client.batch_get_record(
@@ -496,7 +505,10 @@ def transform(raw_s3_data_as_df):
496505
columns=["ingest_time", "write_time", "api_invocation_time", "is_deleted"]
497506
)
498507

499-
assert dataset.equals(get_expected_dataframe())
508+
expected = get_expected_dataframe()
509+
dataset_sorted = dataset.sort_values(by="id").reset_index(drop=True)
510+
expected_sorted = expected.sort_values(by="id").reset_index(drop=True)
511+
assert dataset_sorted.equals(expected_sorted)
500512
finally:
501513
cleanup_offline_store(
502514
feature_group=feature_groups["car_data_feature_group"],
@@ -521,6 +533,8 @@ def transform(raw_s3_data_as_df):
521533
)
522534
def test_feature_processor_transform_offline_only_store_ingestion_run_with_remote(
523535
sagemaker_session,
536+
pre_execution_commands,
537+
dependencies_path,
524538
):
525539
car_data_feature_group_name = get_car_data_feature_group_name()
526540
car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name()
@@ -533,16 +547,10 @@ def test_feature_processor_transform_offline_only_store_ingestion_run_with_remot
533547
)
534548

535549
raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session)
536-
whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session)
537-
whl_file_name = os.path.basename(whl_file_uri)
538-
539-
pre_execution_commands = [
540-
f"aws s3 cp {whl_file_uri} ./",
541-
f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall",
542-
]
543550

544551
@remote(
545552
pre_execution_commands=pre_execution_commands,
553+
dependencies=dependencies_path,
546554
spark_config=SparkConfig(),
547555
instance_type="ml.m5.xlarge",
548556
)
@@ -637,7 +645,10 @@ def transform(raw_s3_data_as_df):
637645
columns=["ingest_time", "write_time", "api_invocation_time", "is_deleted"]
638646
)
639647

640-
assert dataset.equals(get_expected_dataframe())
648+
expected = get_expected_dataframe()
649+
dataset_sorted = dataset.sort_values(by="id").reset_index(drop=True)
650+
expected_sorted = expected.sort_values(by="id").reset_index(drop=True)
651+
assert dataset_sorted.equals(expected_sorted)
641652
finally:
642653
cleanup_offline_store(
643654
feature_group=feature_groups["car_data_feature_group"],
@@ -662,6 +673,8 @@ def transform(raw_s3_data_as_df):
662673
)
663674
def test_to_pipeline_and_execute(
664675
sagemaker_session,
676+
pre_execution_commands,
677+
dependencies_path,
665678
):
666679
pipeline_name = "pipeline-name-01"
667680
car_data_feature_group_name = get_car_data_feature_group_name()
@@ -675,16 +688,10 @@ def test_to_pipeline_and_execute(
675688
)
676689

677690
raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session)
678-
whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session)
679-
whl_file_name = os.path.basename(whl_file_uri)
680-
681-
pre_execution_commands = [
682-
f"aws s3 cp {whl_file_uri} ./",
683-
f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall",
684-
]
685691

686692
@remote(
687693
pre_execution_commands=pre_execution_commands,
694+
dependencies=dependencies_path,
688695
spark_config=SparkConfig(),
689696
instance_type="ml.m5.xlarge",
690697
)
@@ -789,6 +796,8 @@ def transform(raw_s3_data_as_df):
789796
)
790797
def test_schedule_and_event_trigger(
791798
sagemaker_session,
799+
pre_execution_commands,
800+
dependencies_path,
792801
):
793802
pipeline_name = "pipeline-name-01"
794803
car_data_feature_group_name = get_car_data_feature_group_name()
@@ -802,16 +811,10 @@ def test_schedule_and_event_trigger(
802811
)
803812

804813
raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session)
805-
whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session)
806-
whl_file_name = os.path.basename(whl_file_uri)
807-
808-
pre_execution_commands = [
809-
f"aws s3 cp {whl_file_uri} ./",
810-
f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall",
811-
]
812814

813815
@remote(
814816
pre_execution_commands=pre_execution_commands,
817+
dependencies=dependencies_path,
815818
spark_config=SparkConfig(),
816819
instance_type="ml.m5.xlarge",
817820
)
@@ -1042,7 +1045,6 @@ def get_raw_car_data_s3_uri(sagemaker_session) -> str:
10421045
"feature-processor-test",
10431046
"csv-data",
10441047
)
1045-
print("About to upload raw car data to S3...")
10461048
raw_car_data_s3_uri = S3Uploader.upload(
10471049
os.path.join(_FEATURE_PROCESSOR_DIR, "car-data.csv"),
10481050
uri,
@@ -1052,18 +1054,36 @@ def get_raw_car_data_s3_uri(sagemaker_session) -> str:
10521054
return raw_car_data_s3_uri
10531055

10541056

def get_wheel_file_s3_uri(sagemaker_session):
    """Build all SDK wheels and upload them to S3.

    Args:
        sagemaker_session: Session used to resolve the default bucket and
            perform the uploads.

    Returns:
        tuple: (s3_prefix, wheel_names) where wheel_names is the list
        [sagemaker_whl, core_whl, mlops_whl] of wheel file basenames.
    """
    s3_prefix = "s3://{}/{}/wheel-file".format(
        sagemaker_session.default_bucket(), "feature-processor-test"
    )
    sources = _generate_and_move_sagemaker_sdk_tar()
    # Upload each wheel under the shared prefix so a single recursive
    # `aws s3 cp` on the remote host can fetch them all.
    for source in sources:
        S3Uploader.upload(source, s3_prefix, sagemaker_session=sagemaker_session)
    wheel_names = [os.path.basename(s) for s in sources]
    return s3_prefix, wheel_names
1073+
1074+
def get_pre_execution_commands(sagemaker_session):
    """Build SDK wheels, upload them to S3, and return remote install commands.

    The returned commands run on the remote job host before execution: they
    download the wheels, pin setuptools, best-effort install the sagemaker
    wheel with its feature-processor extras, then force-reinstall all three
    wheels without dependencies so the unreleased local builds take precedence.
    """
    s3_prefix, wheel_names = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session)
    sagemaker_whl, core_whl, mlops_whl = wheel_names
    # `=` specifier on all three names so each value is printed with its label.
    print(f"{sagemaker_whl=}, {core_whl=}, {mlops_whl=}")
    return [
        f"aws s3 cp {s3_prefix}/ /tmp/packages/ --recursive",
        "pip3 install 'setuptools<75'",
        # Best-effort dependency install; `|| true` tolerates resolver conflicts.
        f"pip3 install --no-build-isolation '/tmp/packages/{sagemaker_whl}[feature-processor]' 'numpy<2.0.0' 'ml_dtypes<=0.4.1' 'setuptools<75' || true",
        f"pip3 install --no-deps --force-reinstall /tmp/packages/{sagemaker_whl}",
        f"pip3 install --no-deps --force-reinstall /tmp/packages/{core_whl} /tmp/packages/{mlops_whl}",
    ]
10671087

10681088

10691089
def create_feature_groups(
@@ -1170,16 +1190,7 @@ def get_expected_dataframe():
11701190

11711191

11721192
def _wait_for_feature_group_create(feature_group: FeatureGroup):
    """Block until *feature_group* reaches Created status, polling every 5s."""
    feature_group.wait_for_status(target_status="Created", poll=5)
11831194

11841195

11851196
def _wait_for_pipeline_execution_to_stop(pipeline_execution_arn: str, sagemaker_client: client):
@@ -1304,16 +1315,41 @@ def get_sagemaker_client(sagemaker_session=Session) -> client:
13041315

13051316

13061317
def _generate_and_move_sagemaker_sdk_tar():
    """Build the three SDK wheels (sagemaker, sagemaker-core, sagemaker-mlops).

    Returns:
        list: Paths to the three wheel files, in the order
        [sagemaker, sagemaker_core, sagemaker_mlops].

    Raises:
        subprocess.CalledProcessError: If any wheel build fails.
        FileNotFoundError: If an expected wheel is missing after the builds.
    """
    repo_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..")
    )
    dist_dir = os.path.join(repo_root, "dist")

    # Build wheels for all three sub-packages into the shared dist/ directory.
    build_dirs = [
        repo_root,
        os.path.join(repo_root, "sagemaker-core"),
        os.path.join(repo_root, "sagemaker-mlops"),
    ]
    for build_dir in build_dirs:
        # Use the running interpreter and an argument list (no shell) so paths
        # containing spaces cannot break or be injected into the command line.
        subprocess.run(
            [sys.executable, "-m", "build", "--wheel", "--outdir", dist_dir],
            cwd=build_dir,
            check=True,
        )

    # Locate the three expected wheels by prefix pattern.
    wheel_patterns = [
        "sagemaker-[0-9]*.whl",
        "sagemaker_core-*.whl",
        "sagemaker_mlops-*.whl",
    ]
    paths = []
    for pattern in wheel_patterns:
        matches = glob.glob(os.path.join(dist_dir, pattern))
        if not matches:
            raise FileNotFoundError(
                f"No wheel found matching {pattern} in {dist_dir}"
            )
        # dist/ may hold stale wheels from earlier runs; pick the newest build
        # instead of an arbitrary glob ordering to keep the tests deterministic.
        paths.append(max(matches, key=os.path.getmtime))
    return paths
13171353

13181354

13191355
def _wait_for_feature_group_lineage_contexts(

0 commit comments

Comments
 (0)