Skip to content

Commit 215713f

Browse files
authored
fix: load_feature_definitions_from_dataframe() doesn't recognize pandas nullable dtypes (5675) (#5732)
* fix: load_feature_definitions_from_dataframe() doesn't recognize pandas nullable dtypes (5675)
* fix: address review comments (iteration #1)
1 parent b6d86f6 commit 215713f

File tree

2 files changed

+109
-2
lines changed

2 files changed

+109
-2
lines changed

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,33 @@
4444
"string": "String",
4545
"int64": "Integral",
4646
"float64": "Fractional",
47+
# pandas nullable integer dtypes
48+
"Int8": "Integral",
49+
"Int16": "Integral",
50+
"Int32": "Integral",
51+
"Int64": "Integral",
52+
"UInt8": "Integral",
53+
"UInt16": "Integral",
54+
"UInt32": "Integral",
55+
"UInt64": "Integral",
56+
# pandas nullable float dtypes
57+
"Float32": "Fractional",
58+
"Float64": "Fractional",
4759
}
4860

49-
_INTEGER_TYPES = {"int_", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}
50-
_FLOAT_TYPES = {"float_", "float16", "float32", "float64"}
61+
_INTEGER_TYPES = {
62+
"int_", "int8", "int16", "int32", "int64",
63+
"uint8", "uint16", "uint32", "uint64",
64+
# pandas nullable integer dtypes
65+
"Int8", "Int16", "Int32", "Int64",
66+
"UInt8", "UInt16", "UInt32", "UInt64",
67+
}
68+
_FLOAT_TYPES = {
69+
"float_", "float16", "float32", "float64",
70+
# pandas nullable float dtypes
71+
"Float32", "Float64",
72+
}
73+
_STRING_TYPES = {"object", "string"}
5174

5275

5376
def _get_athena_client(session: Session):
@@ -318,6 +341,8 @@ def _generate_feature_definition(
318341
return IntegralFeatureDefinition(series.name, collection_type)
319342
if dtype in _FLOAT_TYPES:
320343
return FractionalFeatureDefinition(series.name, collection_type)
344+
if dtype in _STRING_TYPES:
345+
return StringFeatureDefinition(series.name, collection_type)
321346
return StringFeatureDefinition(series.name, collection_type)
322347

323348

sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py

Lines changed: 82 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,88 @@ def test_returns_correct_count(self, sample_dataframe):
4949
defs = load_feature_definitions_from_dataframe(sample_dataframe)
5050
assert len(defs) == 3
5151

52+
@pytest.mark.parametrize(
53+
"dtype",
54+
["Int8", "Int16", "Int32", "Int64",
55+
"UInt8", "UInt16", "UInt32", "UInt64"],
56+
)
57+
def test_infers_integral_type_with_pandas_nullable_int(
58+
self, dtype
59+
):
60+
df = pd.DataFrame(
61+
{"id": pd.Series([1, 2, 3], dtype=dtype)}
62+
)
63+
defs = load_feature_definitions_from_dataframe(df)
64+
assert defs[0].feature_type == "Integral"
65+
66+
@pytest.mark.parametrize(
67+
"dtype", ["Float32", "Float64"],
68+
)
69+
def test_infers_fractional_type_with_pandas_nullable_float(
70+
self, dtype
71+
):
72+
df = pd.DataFrame(
73+
{"value": pd.Series([1.1, 2.2, 3.3], dtype=dtype)}
74+
)
75+
defs = load_feature_definitions_from_dataframe(df)
76+
assert defs[0].feature_type == "Fractional"
77+
78+
def test_infers_string_type_with_pandas_string_dtype(self):
79+
df = pd.DataFrame({"name": pd.Series(["a", "b", "c"], dtype="string")})
80+
defs = load_feature_definitions_from_dataframe(df)
81+
assert defs[0].feature_type == "String"
82+
83+
def test_infers_correct_types_after_convert_dtypes(self):
84+
df = pd.DataFrame({
85+
"id": [1, 2, 3],
86+
"price": [1.1, 2.2, 3.3],
87+
"name": ["a", "b", "c"],
88+
}).convert_dtypes()
89+
defs = load_feature_definitions_from_dataframe(df)
90+
id_def = next(d for d in defs if d.feature_name == "id")
91+
price_def = next(d for d in defs if d.feature_name == "price")
92+
name_def = next(d for d in defs if d.feature_name == "name")
93+
assert id_def.feature_type == "Integral"
94+
assert price_def.feature_type == "Fractional"
95+
assert name_def.feature_type == "String"
96+
97+
def test_infers_correct_types_with_mixed_nullable_and_numpy_dtypes(
98+
self,
99+
):
100+
df = pd.DataFrame({
101+
"numpy_int": pd.Series([1, 2, 3], dtype="int64"),
102+
"nullable_float": pd.Series(
103+
[1.1, 2.2, 3.3], dtype="Float64"
104+
),
105+
"nullable_int": pd.Series(
106+
[10, 20, 30], dtype="Int64"
107+
),
108+
"numpy_float": pd.Series(
109+
[0.1, 0.2, 0.3], dtype="float64"
110+
),
111+
})
112+
defs = load_feature_definitions_from_dataframe(df)
113+
114+
result = next(
115+
d for d in defs if d.feature_name == "numpy_int"
116+
)
117+
assert result.feature_type == "Integral"
118+
119+
result = next(
120+
d for d in defs if d.feature_name == "nullable_float"
121+
)
122+
assert result.feature_type == "Fractional"
123+
124+
result = next(
125+
d for d in defs if d.feature_name == "nullable_int"
126+
)
127+
assert result.feature_type == "Integral"
128+
129+
result = next(
130+
d for d in defs if d.feature_name == "numpy_float"
131+
)
132+
assert result.feature_type == "Fractional"
133+
52134
def test_collection_type_with_in_memory_storage(self):
53135
df = pd.DataFrame({
54136
"id": pd.Series([1, 2], dtype="int64"),

0 commit comments

Comments
 (0)