1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
1414
15+ import glob
1516import logging
1617import os
1718import subprocess
1819import sys
20+ import tempfile
1921import time
2022from typing import Dict
2123from datetime import datetime
@@ -121,6 +123,16 @@ def sagemaker_session():
121123 return Session ()
122124
123125
@pytest.fixture(scope="module")
def pre_execution_commands(sagemaker_session):
    """Shell commands that install the freshly built SDK wheels on the remote job.

    Module-scoped so the wheels are built and uploaded to S3 only once per
    test module instead of once per test.
    """
    return get_pre_execution_commands(sagemaker_session=sagemaker_session)
129+
130+
@pytest.fixture(scope="module")
def dependencies_path():
    """Path to the requirements.txt passed to ``@remote`` as ``dependencies``."""
    return os.path.join(_FEATURE_PROCESSOR_DIR, "requirements.txt")
134+
135+
124136@pytest .mark .slow_test
125137def test_feature_processor_transform_online_only_store_ingestion (
126138 sagemaker_session ,
@@ -137,8 +149,6 @@ def test_feature_processor_transform_online_only_store_ingestion(
137149
138150 raw_data_uri = get_raw_car_data_s3_uri (sagemaker_session = sagemaker_session )
139151
140- print ("About to apply @feature_processor decorator..." )
141-
142152 @feature_processor (
143153 inputs = [CSVDataSource (raw_data_uri )],
144154 output = feature_groups ["car_data_arn" ],
@@ -175,9 +185,8 @@ def transform(raw_s3_data_as_df):
175185 transformed_df .show ()
176186 return transformed_df
177187
178- print ("Decorator applied. About to call transform()..." )
179- transform ()
180- print ("transform() completed." )
    # This invokes Spark 3.3, which requires Java 11.
189+ transform ()
181190
182191 featurestore_client = sagemaker_session .sagemaker_featurestore_runtime_client
183192 results = featurestore_client .batch_get_record (
@@ -496,7 +505,10 @@ def transform(raw_s3_data_as_df):
496505 columns = ["ingest_time" , "write_time" , "api_invocation_time" , "is_deleted" ]
497506 )
498507
499- assert dataset .equals (get_expected_dataframe ())
508+ expected = get_expected_dataframe ()
509+ dataset_sorted = dataset .sort_values (by = "id" ).reset_index (drop = True )
510+ expected_sorted = expected .sort_values (by = "id" ).reset_index (drop = True )
511+ assert dataset_sorted .equals (expected_sorted )
500512 finally :
501513 cleanup_offline_store (
502514 feature_group = feature_groups ["car_data_feature_group" ],
@@ -521,6 +533,8 @@ def transform(raw_s3_data_as_df):
521533)
522534def test_feature_processor_transform_offline_only_store_ingestion_run_with_remote (
523535 sagemaker_session ,
536+ pre_execution_commands ,
537+ dependencies_path ,
524538):
525539 car_data_feature_group_name = get_car_data_feature_group_name ()
526540 car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name ()
@@ -533,16 +547,10 @@ def test_feature_processor_transform_offline_only_store_ingestion_run_with_remot
533547 )
534548
535549 raw_data_uri = get_raw_car_data_s3_uri (sagemaker_session = sagemaker_session )
536- whl_file_uri = get_wheel_file_s3_uri (sagemaker_session = sagemaker_session )
537- whl_file_name = os .path .basename (whl_file_uri )
538-
539- pre_execution_commands = [
540- f"aws s3 cp { whl_file_uri } ./" ,
541- f"/usr/local/bin/python3.9 -m pip install ./{ whl_file_name } --force-reinstall" ,
542- ]
543550
544551 @remote (
545552 pre_execution_commands = pre_execution_commands ,
553+ dependencies = dependencies_path ,
546554 spark_config = SparkConfig (),
547555 instance_type = "ml.m5.xlarge" ,
548556 )
@@ -637,7 +645,10 @@ def transform(raw_s3_data_as_df):
637645 columns = ["ingest_time" , "write_time" , "api_invocation_time" , "is_deleted" ]
638646 )
639647
640- assert dataset .equals (get_expected_dataframe ())
648+ expected = get_expected_dataframe ()
649+ dataset_sorted = dataset .sort_values (by = "id" ).reset_index (drop = True )
650+ expected_sorted = expected .sort_values (by = "id" ).reset_index (drop = True )
651+ assert dataset_sorted .equals (expected_sorted )
641652 finally :
642653 cleanup_offline_store (
643654 feature_group = feature_groups ["car_data_feature_group" ],
@@ -662,6 +673,8 @@ def transform(raw_s3_data_as_df):
662673)
663674def test_to_pipeline_and_execute (
664675 sagemaker_session ,
676+ pre_execution_commands ,
677+ dependencies_path ,
665678):
666679 pipeline_name = "pipeline-name-01"
667680 car_data_feature_group_name = get_car_data_feature_group_name ()
@@ -675,16 +688,10 @@ def test_to_pipeline_and_execute(
675688 )
676689
677690 raw_data_uri = get_raw_car_data_s3_uri (sagemaker_session = sagemaker_session )
678- whl_file_uri = get_wheel_file_s3_uri (sagemaker_session = sagemaker_session )
679- whl_file_name = os .path .basename (whl_file_uri )
680-
681- pre_execution_commands = [
682- f"aws s3 cp { whl_file_uri } ./" ,
683- f"/usr/local/bin/python3.9 -m pip install ./{ whl_file_name } --force-reinstall" ,
684- ]
685691
686692 @remote (
687693 pre_execution_commands = pre_execution_commands ,
694+ dependencies = dependencies_path ,
688695 spark_config = SparkConfig (),
689696 instance_type = "ml.m5.xlarge" ,
690697 )
@@ -789,6 +796,8 @@ def transform(raw_s3_data_as_df):
789796)
790797def test_schedule_and_event_trigger (
791798 sagemaker_session ,
799+ pre_execution_commands ,
800+ dependencies_path ,
792801):
793802 pipeline_name = "pipeline-name-01"
794803 car_data_feature_group_name = get_car_data_feature_group_name ()
@@ -802,16 +811,10 @@ def test_schedule_and_event_trigger(
802811 )
803812
804813 raw_data_uri = get_raw_car_data_s3_uri (sagemaker_session = sagemaker_session )
805- whl_file_uri = get_wheel_file_s3_uri (sagemaker_session = sagemaker_session )
806- whl_file_name = os .path .basename (whl_file_uri )
807-
808- pre_execution_commands = [
809- f"aws s3 cp { whl_file_uri } ./" ,
810- f"/usr/local/bin/python3.9 -m pip install ./{ whl_file_name } --force-reinstall" ,
811- ]
812814
813815 @remote (
814816 pre_execution_commands = pre_execution_commands ,
817+ dependencies = dependencies_path ,
815818 spark_config = SparkConfig (),
816819 instance_type = "ml.m5.xlarge" ,
817820 )
@@ -1042,7 +1045,6 @@ def get_raw_car_data_s3_uri(sagemaker_session) -> str:
10421045 "feature-processor-test" ,
10431046 "csv-data" ,
10441047 )
1045- print ("About to upload raw car data to S3..." )
10461048 raw_car_data_s3_uri = S3Uploader .upload (
10471049 os .path .join (_FEATURE_PROCESSOR_DIR , "car-data.csv" ),
10481050 uri ,
@@ -1052,18 +1054,36 @@ def get_raw_car_data_s3_uri(sagemaker_session) -> str:
10521054 return raw_car_data_s3_uri
10531055
10541056
def get_wheel_file_s3_uri(sagemaker_session):
    """Upload all SDK wheels to S3 and return (s3_prefix, wheel_basenames).

    Args:
        sagemaker_session: Session used to resolve the default bucket.

    Returns:
        tuple: (s3_prefix, [sagemaker_whl, core_whl, mlops_whl]) where each
        list element is the basename of the corresponding wheel file.
    """
    s3_prefix = "s3://{}/{}/wheel-file".format(
        sagemaker_session.default_bucket(), "feature-processor-test"
    )
    sources = _generate_and_move_sagemaker_sdk_tar()
    for source in sources:
        S3Uploader.upload(source, s3_prefix, sagemaker_session=sagemaker_session)
    wheel_names = [os.path.basename(source) for source in sources]
    return s3_prefix, wheel_names
1073+
1074+
def get_pre_execution_commands(sagemaker_session):
    """Build SDK wheels, upload them to S3, and return pre-execution install commands.

    Args:
        sagemaker_session: Session used to resolve the default bucket.

    Returns:
        list[str]: shell commands run on the remote job before the payload,
        downloading the three wheels and installing them.
    """
    s3_prefix, wheel_names = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session)
    sagemaker_whl, core_whl, mlops_whl = wheel_names
    # Use the '=' debug specifier for all three names (previously mlops_whl
    # was printed without it, unlike the other two).
    print(f"{sagemaker_whl=}, {core_whl=}, {mlops_whl=}")
    return [
        f"aws s3 cp {s3_prefix}/ /tmp/packages/ --recursive",
        "pip3 install 'setuptools<75'",
        # Best-effort dependency resolution ('|| true'); the pinned wheels are
        # force-installed below regardless of this step's outcome.
        f"pip3 install --no-build-isolation '/tmp/packages/{sagemaker_whl}[feature-processor]' 'numpy<2.0.0' 'ml_dtypes<=0.4.1' 'setuptools<75' || true",
        f"pip3 install --no-deps --force-reinstall /tmp/packages/{sagemaker_whl}",
        f"pip3 install --no-deps --force-reinstall /tmp/packages/{core_whl} /tmp/packages/{mlops_whl}",
    ]
10671087
10681088
10691089def create_feature_groups (
@@ -1170,16 +1190,7 @@ def get_expected_dataframe():
11701190
11711191
11721192def _wait_for_feature_group_create (feature_group : FeatureGroup ):
1173- status = feature_group .feature_group_status
1174- while status == "Creating" :
1175- print ("Waiting for Feature Group Creation" )
1176- time .sleep (5 )
1177- feature_group .refresh ()
1178- status = feature_group .feature_group_status
1179- if status != "Created" :
1180- print (f"FeatureGroup { feature_group .feature_group_name } status: { status } " )
1181- raise RuntimeError (f"Failed to create feature group { feature_group .feature_group_name } " )
1182- print (f"FeatureGroup { feature_group .feature_group_name } successfully created." )
1193+ feature_group .wait_for_status (target_status = "Created" , poll = 5 )
11831194
11841195
11851196def _wait_for_pipeline_execution_to_stop (pipeline_execution_arn : str , sagemaker_client : client ):
@@ -1304,16 +1315,41 @@ def get_sagemaker_client(sagemaker_session=Session) -> client:
13041315
13051316
def _generate_and_move_sagemaker_sdk_tar():
    """Build the three SDK wheel files and return their paths under dist/.

    Builds wheels for the top-level ``sagemaker`` package and the
    ``sagemaker-core`` / ``sagemaker-mlops`` sub-packages into the shared
    ``dist/`` directory, then resolves one wheel path per package.

    Returns:
        list[str]: [sagemaker_whl, core_whl, mlops_whl] absolute paths.

    Raises:
        FileNotFoundError: if an expected wheel is missing after the builds.
        subprocess.CalledProcessError: if any wheel build fails.
    """
    repo_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..")
    )
    dist_dir = os.path.join(repo_root, "dist")

    # Build wheels for all three sub-packages into the shared dist/ directory.
    # Use the current interpreter and an argv list (shell=False) so the build
    # is not sensitive to PATH lookup or shell quoting.
    build_dirs = [
        repo_root,
        os.path.join(repo_root, "sagemaker-core"),
        os.path.join(repo_root, "sagemaker-mlops"),
    ]
    for build_dir in build_dirs:
        subprocess.run(
            [sys.executable, "-m", "build", "--wheel", "--outdir", dist_dir],
            cwd=build_dir,
            check=True,
        )

    # Locate the three expected wheels by prefix pattern. glob order is
    # arbitrary, so when multiple versions exist pick the most recently
    # built one deterministically.
    wheel_patterns = [
        "sagemaker-[0-9]*.whl",
        "sagemaker_core-*.whl",
        "sagemaker_mlops-*.whl",
    ]
    paths = []
    for pattern in wheel_patterns:
        matches = glob.glob(os.path.join(dist_dir, pattern))
        if not matches:
            raise FileNotFoundError(f"No wheel found matching {pattern} in {dist_dir}")
        paths.append(max(matches, key=os.path.getmtime))
    return paths
13171353
13181354
13191355def _wait_for_feature_group_lineage_contexts (
# (end of diff excerpt)