Skip to content

Commit 193d16f

Browse files
committed
Add feature store tests
1 parent 534df91 commit 193d16f

File tree

17 files changed

+1802
-47
lines changed

17 files changed

+1802
-47
lines changed

sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md

Lines changed: 513 additions & 0 deletions
Large diffs are not rendered by default.

sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
"""SageMaker FeatureStore V3 - powered by sagemaker-core."""
44

55
# Resources from core
6-
from sagemaker_core.main.resources import FeatureGroup, FeatureMetadata
7-
from sagemaker_core.main.resources import FeatureStore
6+
from sagemaker.core.resources import FeatureGroup, FeatureMetadata
87

98
# Shapes from core (Pydantic - no to_dict() needed)
10-
from sagemaker_core.main.shapes import (
9+
from sagemaker.core.shapes import (
1110
DataCatalogConfig,
1211
FeatureParameter,
1312
FeatureValue,
@@ -52,7 +51,6 @@
5251
from sagemaker.mlops.feature_store.feature_utils import (
5352
as_hive_ddl,
5453
create_athena_query,
55-
create_dataset,
5654
get_session_from_role,
5755
ingest_dataframe,
5856
load_feature_definitions_from_dataframe,
@@ -76,7 +74,6 @@
7674
# Resources
7775
"FeatureGroup",
7876
"FeatureMetadata",
79-
"FeatureStore",
8077
# Shapes
8178
"DataCatalogConfig",
8279
"FeatureParameter",
@@ -113,7 +110,6 @@
113110
# Utility functions
114111
"as_hive_ddl",
115112
"create_athena_query",
116-
"create_dataset",
117113
"get_session_from_role",
118114
"ingest_dataframe",
119115
"load_feature_definitions_from_dataframe",

sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ def construct_feature_group_to_be_merged(
178178
database=catalog_config.database,
179179
table_name=catalog_config.table_name,
180180
record_identifier_feature_name=record_id,
181-
event_time_identifier_feature=FeatureDefinition(event_time_name, FeatureTypeEnum(event_time_type)),
181+
event_time_identifier_feature=FeatureDefinition(
182+
feature_name=event_time_name, feature_type=FeatureTypeEnum(event_time_type).value
183+
),
182184
target_feature_name_in_base=target_feature_name_in_base,
183185
table_type=TableType.FEATURE_GROUP,
184186
feature_name_in_target=feature_name_in_target,
@@ -256,6 +258,47 @@ class DatasetBuilder:
256258
_event_time_ending_timestamp: datetime.datetime = field(default=None, init=False)
257259
_feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = field(default_factory=list, init=False)
258260

261+
@classmethod
262+
def create(
263+
cls,
264+
base: Union[FeatureGroup, pd.DataFrame],
265+
output_path: str,
266+
session: Session,
267+
record_identifier_feature_name: str = None,
268+
event_time_identifier_feature_name: str = None,
269+
included_feature_names: List[str] = None,
270+
kms_key_id: str = None,
271+
) -> "DatasetBuilder":
272+
"""Create a DatasetBuilder for generating a Dataset.
273+
274+
Args:
275+
base: A FeatureGroup or DataFrame to use as the base.
276+
output_path: S3 URI for output.
277+
session: SageMaker session.
278+
record_identifier_feature_name: Required if base is DataFrame.
279+
event_time_identifier_feature_name: Required if base is DataFrame.
280+
included_feature_names: Features to include in output.
281+
kms_key_id: KMS key for encryption.
282+
283+
Returns:
284+
DatasetBuilder instance.
285+
"""
286+
if isinstance(base, pd.DataFrame):
287+
if not record_identifier_feature_name or not event_time_identifier_feature_name:
288+
raise ValueError(
289+
"record_identifier_feature_name and event_time_identifier_feature_name "
290+
"are required when base is a DataFrame."
291+
)
292+
return cls(
293+
_sagemaker_session=session,
294+
_base=base,
295+
_output_path=output_path,
296+
_record_identifier_feature_name=record_identifier_feature_name,
297+
_event_time_identifier_feature_name=event_time_identifier_feature_name,
298+
_included_feature_names=included_feature_names,
299+
_kms_key_id=kms_key_id,
300+
)
301+
259302
def with_feature_group(
260303
self,
261304
feature_group: FeatureGroup,

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from sagemaker.mlops.feature_store import FeatureGroup as CoreFeatureGroup, FeatureGroup
1414
from sagemaker.core.helper.session_helper import Session
1515
from sagemaker.core.s3.client import S3Uploader, S3Downloader
16-
from sagemaker.mlops.feature_store.dataset_builder import DatasetBuilder
1716
from sagemaker.mlops.feature_store.feature_definition import (
1817
FeatureDefinition,
1918
FractionalFeatureDefinition,
@@ -23,7 +22,7 @@
2322
)
2423
from sagemaker.mlops.feature_store.ingestion_manager_pandas import IngestionManagerPandas
2524

26-
from sagemaker import utils
25+
from sagemaker.core.utils import unique_name_from_base
2726

2827

2928
logger = logging.getLogger(__name__)
@@ -207,7 +206,7 @@ def upload_dataframe_to_s3(
207206
Tuple of (s3_folder, temp_table_name).
208207
"""
209208

210-
temp_id = utils.unique_name_from_base("dataframe-base")
209+
temp_id = unique_name_from_base("dataframe-base")
211210
local_file = f"{temp_id}.csv"
212211
s3_folder = os.path.join(output_path, temp_id)
213212

@@ -460,29 +459,3 @@ def ingest_dataframe(
460459
manager.run(data_frame=data_frame, wait=wait, timeout=timeout)
461460
return manager
462461

463-
def create_dataset(
464-
base: Union[FeatureGroup, pd.DataFrame],
465-
output_path: str,
466-
session: Session,
467-
record_identifier_feature_name: str = None,
468-
event_time_identifier_feature_name: str = None,
469-
included_feature_names: Sequence[str] = None,
470-
kms_key_id: str = None,
471-
) -> DatasetBuilder:
472-
"""Create a DatasetBuilder for generating a Dataset."""
473-
if isinstance(base, pd.DataFrame):
474-
if not record_identifier_feature_name or not event_time_identifier_feature_name:
475-
raise ValueError(
476-
"record_identifier_feature_name and event_time_identifier_feature_name "
477-
"are required when base is a DataFrame."
478-
)
479-
return DatasetBuilder(
480-
_sagemaker_session=session,
481-
_base=base,
482-
_output_path=output_path,
483-
_record_identifier_feature_name=record_identifier_feature_name,
484-
_event_time_identifier_feature_name=event_time_identifier_feature_name,
485-
_included_feature_names=included_feature_names,
486-
_kms_key_id=kms_key_id,
487-
)
488-

sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
import signal
77
from concurrent.futures import ThreadPoolExecutor, as_completed
88
from dataclasses import dataclass, field
9-
from multiprocessing.pool import AsyncResult
9+
from multiprocessing import Pool
1010
from typing import Any, Dict, Iterable, List, Sequence, Union
1111

1212
import pandas as pd
1313
from pandas import DataFrame
1414
from pandas.api.types import is_list_like
15-
from pathos.multiprocessing import ProcessingPool
1615

1716
from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup
1817
from sagemaker.core.shapes import FeatureValue
@@ -54,8 +53,8 @@ class IngestionManagerPandas:
5453
feature_definitions: Dict[str, Dict[Any, Any]]
5554
max_workers: int = 1
5655
max_processes: int = 1
57-
_async_result: AsyncResult = field(default=None, init=False)
58-
_processing_pool: ProcessingPool = field(default=None, init=False)
56+
_async_result: Any = field(default=None, init=False)
57+
_processing_pool: Pool = field(default=None, init=False)
5958
_failed_indices: List[int] = field(default_factory=list, init=False)
6059

6160
@property
@@ -100,12 +99,11 @@ def wait(self, timeout: Union[int, float] = None):
10099
results = self._async_result.get(timeout=timeout)
101100
except KeyboardInterrupt as e:
102101
self._processing_pool.terminate()
103-
self._processing_pool.close()
104-
self._processing_pool.clear()
102+
self._processing_pool.join()
105103
raise e
106104
else:
107105
self._processing_pool.close()
108-
self._processing_pool.clear()
106+
self._processing_pool.join()
109107

110108
self._failed_indices = [idx for failed in results for idx in failed]
111109

@@ -170,11 +168,10 @@ def _run_multi_process(
170168
def init_worker():
171169
signal.signal(signal.SIGINT, signal.SIG_IGN)
172170

173-
self._processing_pool = ProcessingPool(self.max_processes, init_worker)
174-
self._processing_pool.restart(force=True)
171+
self._processing_pool = Pool(self.max_processes, init_worker)
175172

176-
self._async_result = self._processing_pool.amap(
177-
lambda x: IngestionManagerPandas._run_multi_threaded(*x),
173+
self._async_result = self._processing_pool.starmap_async(
174+
IngestionManagerPandas._run_multi_threaded,
178175
args,
179176
)
180177

sagemaker-mlops/tests/__init__.py

Whitespace-only changes.

sagemaker-mlops/tests/unit/__init__.py

Whitespace-only changes.

sagemaker-mlops/tests/unit/sagemaker/__init__.py

Whitespace-only changes.

sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# Licensed under the Apache License, Version 2.0

0 commit comments

Comments
 (0)