Skip to content

Commit 0d68f0a

Browse files
committed
fix: load_feature_definitions_from_dataframe() doesn't recognize pandas nullable dtypes (#5675)
1 parent 272fdbf commit 0d68f0a

File tree

2 files changed

+85
-2
lines changed

2 files changed

+85
-2
lines changed

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,19 @@
4646
"float64": "Fractional",
4747
}
4848

49-
_INTEGER_TYPES = {"int_", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}
50-
_FLOAT_TYPES = {"float_", "float16", "float32", "float64"}
49+
_INTEGER_TYPES = {
50+
"int_", "int8", "int16", "int32", "int64",
51+
"uint8", "uint16", "uint32", "uint64",
52+
# pandas nullable integer dtypes
53+
"Int8", "Int16", "Int32", "Int64",
54+
"UInt8", "UInt16", "UInt32", "UInt64",
55+
}
56+
_FLOAT_TYPES = {
57+
"float_", "float16", "float32", "float64",
58+
# pandas nullable float dtypes
59+
"Float32", "Float64",
60+
}
61+
_STRING_TYPES = {"object", "string"}
5162

5263

5364
def _get_athena_client(session: Session):

sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py

Lines changed: 72 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,78 @@ def test_returns_correct_count(self, sample_dataframe):
4949
defs = load_feature_definitions_from_dataframe(sample_dataframe)
5050
assert len(defs) == 3
5151

52+
def test_infers_integral_type_with_pandas_nullable_Int64(self):
53+
df = pd.DataFrame({"id": pd.Series([1, 2, 3], dtype="Int64")})
54+
defs = load_feature_definitions_from_dataframe(df)
55+
assert defs[0].feature_type == "Integral"
56+
57+
def test_infers_integral_type_with_pandas_nullable_Int32(self):
58+
df = pd.DataFrame({"id": pd.Series([1, 2, 3], dtype="Int32")})
59+
defs = load_feature_definitions_from_dataframe(df)
60+
assert defs[0].feature_type == "Integral"
61+
62+
def test_infers_integral_type_with_pandas_nullable_Int16(self):
63+
df = pd.DataFrame({"id": pd.Series([1, 2, 3], dtype="Int16")})
64+
defs = load_feature_definitions_from_dataframe(df)
65+
assert defs[0].feature_type == "Integral"
66+
67+
def test_infers_integral_type_with_pandas_nullable_Int8(self):
68+
df = pd.DataFrame({"id": pd.Series([1, 2, 3], dtype="Int8")})
69+
defs = load_feature_definitions_from_dataframe(df)
70+
assert defs[0].feature_type == "Integral"
71+
72+
def test_infers_integral_type_with_pandas_nullable_UInt64(self):
73+
df = pd.DataFrame({"id": pd.Series([1, 2, 3], dtype="UInt64")})
74+
defs = load_feature_definitions_from_dataframe(df)
75+
assert defs[0].feature_type == "Integral"
76+
77+
def test_infers_integral_type_with_pandas_nullable_UInt32(self):
78+
df = pd.DataFrame({"id": pd.Series([1, 2, 3], dtype="UInt32")})
79+
defs = load_feature_definitions_from_dataframe(df)
80+
assert defs[0].feature_type == "Integral"
81+
82+
def test_infers_fractional_type_with_pandas_nullable_Float64(self):
83+
df = pd.DataFrame({"value": pd.Series([1.1, 2.2, 3.3], dtype="Float64")})
84+
defs = load_feature_definitions_from_dataframe(df)
85+
assert defs[0].feature_type == "Fractional"
86+
87+
def test_infers_fractional_type_with_pandas_nullable_Float32(self):
88+
df = pd.DataFrame({"value": pd.Series([1.1, 2.2], dtype="Float32")})
89+
defs = load_feature_definitions_from_dataframe(df)
90+
assert defs[0].feature_type == "Fractional"
91+
92+
def test_infers_string_type_with_pandas_string_dtype(self):
93+
df = pd.DataFrame({"name": pd.Series(["a", "b", "c"], dtype="string")})
94+
defs = load_feature_definitions_from_dataframe(df)
95+
assert defs[0].feature_type == "String"
96+
97+
def test_infers_correct_types_after_convert_dtypes(self):
98+
df = pd.DataFrame({
99+
"id": [1, 2, 3],
100+
"price": [1.1, 2.2, 3.3],
101+
"name": ["a", "b", "c"],
102+
}).convert_dtypes()
103+
defs = load_feature_definitions_from_dataframe(df)
104+
id_def = next(d for d in defs if d.feature_name == "id")
105+
price_def = next(d for d in defs if d.feature_name == "price")
106+
name_def = next(d for d in defs if d.feature_name == "name")
107+
assert id_def.feature_type == "Integral"
108+
assert price_def.feature_type == "Fractional"
109+
assert name_def.feature_type == "String"
110+
111+
def test_infers_correct_types_with_mixed_nullable_and_numpy_dtypes(self):
112+
df = pd.DataFrame({
113+
"numpy_int": pd.Series([1, 2, 3], dtype="int64"),
114+
"nullable_float": pd.Series([1.1, 2.2, 3.3], dtype="Float64"),
115+
"nullable_int": pd.Series([10, 20, 30], dtype="Int64"),
116+
"numpy_float": pd.Series([0.1, 0.2, 0.3], dtype="float64"),
117+
})
118+
defs = load_feature_definitions_from_dataframe(df)
119+
assert next(d for d in defs if d.feature_name == "numpy_int").feature_type == "Integral"
120+
assert next(d for d in defs if d.feature_name == "nullable_float").feature_type == "Fractional"
121+
assert next(d for d in defs if d.feature_name == "nullable_int").feature_type == "Integral"
122+
assert next(d for d in defs if d.feature_name == "numpy_float").feature_type == "Fractional"
123+
52124
def test_collection_type_with_in_memory_storage(self):
53125
df = pd.DataFrame({
54126
"id": pd.Series([1, 2], dtype="int64"),

0 commit comments

Comments
 (0)