Skip to content

Commit 41c2541

Browse files
Aditi2424adishaa
andauthored
Feature store v3 (#5490)
* feat: Add Feature Store Support to V3 * Add feature store tests --------- Co-authored-by: adishaa <adishaa@amazon.com>
1 parent ad190b9 commit 41c2541

20 files changed

Lines changed: 3697 additions & 0 deletions

sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md

Lines changed: 513 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# Licensed under the Apache License, Version 2.0
3+
"""SageMaker FeatureStore V3 - powered by sagemaker-core."""
4+
5+
# Resources from core
6+
from sagemaker.core.resources import FeatureGroup, FeatureMetadata
7+
8+
# Shapes from core (Pydantic - no to_dict() needed)
9+
from sagemaker.core.shapes import (
10+
DataCatalogConfig,
11+
FeatureParameter,
12+
FeatureValue,
13+
Filter,
14+
OfflineStoreConfig,
15+
OnlineStoreConfig,
16+
OnlineStoreSecurityConfig,
17+
S3StorageConfig,
18+
SearchExpression,
19+
ThroughputConfig,
20+
TtlDuration,
21+
)
22+
23+
# Enums (local - core uses strings)
24+
from sagemaker.mlops.feature_store.inputs import (
25+
DeletionModeEnum,
26+
ExpirationTimeResponseEnum,
27+
FilterOperatorEnum,
28+
OnlineStoreStorageTypeEnum,
29+
ResourceEnum,
30+
SearchOperatorEnum,
31+
SortOrderEnum,
32+
TableFormatEnum,
33+
TargetStoreEnum,
34+
ThroughputModeEnum,
35+
)
36+
37+
# Feature Definition helpers (local)
38+
from sagemaker.mlops.feature_store.feature_definition import (
39+
FeatureDefinition,
40+
FeatureTypeEnum,
41+
CollectionTypeEnum,
42+
FractionalFeatureDefinition,
43+
IntegralFeatureDefinition,
44+
StringFeatureDefinition,
45+
ListCollectionType,
46+
SetCollectionType,
47+
VectorCollectionType,
48+
)
49+
50+
# Utility functions (local)
51+
from sagemaker.mlops.feature_store.feature_utils import (
52+
as_hive_ddl,
53+
create_athena_query,
54+
get_session_from_role,
55+
ingest_dataframe,
56+
load_feature_definitions_from_dataframe,
57+
)
58+
59+
# Classes (local)
60+
from sagemaker.mlops.feature_store.athena_query import AthenaQuery
61+
from sagemaker.mlops.feature_store.dataset_builder import (
62+
DatasetBuilder,
63+
FeatureGroupToBeMerged,
64+
JoinComparatorEnum,
65+
JoinTypeEnum,
66+
TableType,
67+
)
68+
from sagemaker.mlops.feature_store.ingestion_manager_pandas import (
69+
IngestionError,
70+
IngestionManagerPandas,
71+
)
72+
73+
__all__ = [
74+
# Resources
75+
"FeatureGroup",
76+
"FeatureMetadata",
77+
# Shapes
78+
"DataCatalogConfig",
79+
"FeatureParameter",
80+
"FeatureValue",
81+
"Filter",
82+
"OfflineStoreConfig",
83+
"OnlineStoreConfig",
84+
"OnlineStoreSecurityConfig",
85+
"S3StorageConfig",
86+
"SearchExpression",
87+
"ThroughputConfig",
88+
"TtlDuration",
89+
# Enums
90+
"DeletionModeEnum",
91+
"ExpirationTimeResponseEnum",
92+
"FilterOperatorEnum",
93+
"OnlineStoreStorageTypeEnum",
94+
"ResourceEnum",
95+
"SearchOperatorEnum",
96+
"SortOrderEnum",
97+
"TableFormatEnum",
98+
"TargetStoreEnum",
99+
"ThroughputModeEnum",
100+
# Feature Definitions
101+
"FeatureDefinition",
102+
"FeatureTypeEnum",
103+
"CollectionTypeEnum",
104+
"FractionalFeatureDefinition",
105+
"IntegralFeatureDefinition",
106+
"StringFeatureDefinition",
107+
"ListCollectionType",
108+
"SetCollectionType",
109+
"VectorCollectionType",
110+
# Utility functions
111+
"as_hive_ddl",
112+
"create_athena_query",
113+
"get_session_from_role",
114+
"ingest_dataframe",
115+
"load_feature_definitions_from_dataframe",
116+
# Classes
117+
"AthenaQuery",
118+
"DatasetBuilder",
119+
"FeatureGroupToBeMerged",
120+
"IngestionError",
121+
"IngestionManagerPandas",
122+
"JoinComparatorEnum",
123+
"JoinTypeEnum",
124+
"TableType",
125+
]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import os
2+
import tempfile
3+
from dataclasses import dataclass, field
4+
from typing import Any, Dict
5+
from urllib.parse import urlparse
6+
import pandas as pd
7+
from pandas import DataFrame
8+
9+
from sagemaker.mlops.feature_store.feature_utils import (
10+
start_query_execution,
11+
get_query_execution,
12+
wait_for_athena_query,
13+
download_athena_query_result,
14+
)
15+
16+
from sagemaker.core.helper.session_helper import Session
17+
18+
@dataclass
19+
class AthenaQuery:
20+
"""Class to manage querying of feature store data with AWS Athena.
21+
22+
This class instantiates a AthenaQuery object that is used to retrieve data from feature store
23+
via standard SQL queries.
24+
25+
Attributes:
26+
catalog (str): name of the data catalog.
27+
database (str): name of the database.
28+
table_name (str): name of the table.
29+
sagemaker_session (Session): instance of the Session class to perform boto calls.
30+
"""
31+
32+
catalog: str
33+
database: str
34+
table_name: str
35+
sagemaker_session: Session
36+
_current_query_execution_id: str = field(default=None, init=False)
37+
_result_bucket: str = field(default=None, init=False)
38+
_result_file_prefix: str = field(default=None, init=False)
39+
40+
def run(
41+
self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None
42+
) -> str:
43+
"""Execute a SQL query given a query string, output location and kms key.
44+
45+
This method executes the SQL query using Athena and outputs the results to output_location
46+
and returns the execution id of the query.
47+
48+
Args:
49+
query_string: SQL query string.
50+
output_location: S3 URI of the query result.
51+
kms_key: KMS key id. If set, will be used to encrypt the query result file.
52+
workgroup (str): The name of the workgroup in which the query is being started.
53+
54+
Returns:
55+
Execution id of the query.
56+
"""
57+
response = start_query_execution(
58+
session=self.sagemaker_session,
59+
catalog=self.catalog,
60+
database=self.database,
61+
query_string=query_string,
62+
output_location=output_location,
63+
kms_key=kms_key,
64+
workgroup=workgroup,
65+
)
66+
67+
self._current_query_execution_id = response["QueryExecutionId"]
68+
parsed_result = urlparse(output_location, allow_fragments=False)
69+
self._result_bucket = parsed_result.netloc
70+
self._result_file_prefix = parsed_result.path.strip("/")
71+
return self._current_query_execution_id
72+
73+
def wait(self):
74+
"""Wait for the current query to finish."""
75+
wait_for_athena_query(self.sagemaker_session, self._current_query_execution_id)
76+
77+
def get_query_execution(self) -> Dict[str, Any]:
78+
"""Get execution status of the current query.
79+
80+
Returns:
81+
Response dict from Athena.
82+
"""
83+
return get_query_execution(self.sagemaker_session, self._current_query_execution_id)
84+
85+
def as_dataframe(self, **kwargs) -> DataFrame:
86+
"""Download the result of the current query and load it into a DataFrame.
87+
88+
Args:
89+
**kwargs (object): key arguments used for the method pandas.read_csv to be able to
90+
have a better tuning on data. For more info read:
91+
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
92+
93+
Returns:
94+
A pandas DataFrame contains the query result.
95+
"""
96+
state = self.get_query_execution()["QueryExecution"]["Status"]["State"]
97+
if state != "SUCCEEDED":
98+
if state in ("QUEUED", "RUNNING"):
99+
raise RuntimeError(f"Query {self._current_query_execution_id} still executing.")
100+
raise RuntimeError(f"Query {self._current_query_execution_id} failed.")
101+
102+
output_file = os.path.join(tempfile.gettempdir(), f"{self._current_query_execution_id}.csv")
103+
download_athena_query_result(
104+
session=self.sagemaker_session,
105+
bucket=self._result_bucket,
106+
prefix=self._result_file_prefix,
107+
query_execution_id=self._current_query_execution_id,
108+
filename=output_file,
109+
)
110+
kwargs.pop("delimiter", None)
111+
return pd.read_csv(output_file, delimiter=",", **kwargs)
112+

0 commit comments

Comments
 (0)