Skip to content

Commit c7f1faa

Browse files
committed
Add deprecation path for the hierachical schema specification in sklearn
1 parent 5ae4346 commit c7f1faa

File tree

2 files changed

+170
-8
lines changed

2 files changed

+170
-8
lines changed

khiops/sklearn/dataset.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
import khiops.core as kh
2222
import khiops.core.internals.filesystems as fs
2323
from khiops.core.dictionary import VariableBlock
24-
from khiops.core.internals.common import is_dict_like, is_list_like, type_error_message
24+
from khiops.core.internals.common import (
25+
deprecation_message,
26+
is_dict_like,
27+
is_list_like,
28+
type_error_message,
29+
)
2530

2631
# Disable PEP8 variable names because of scikit-learn X,y conventions
2732
# To capture invalid-names other than X,y run:
@@ -171,6 +176,54 @@ def _check_multitable_spec(ds_spec):
171176
)
172177

173178

179+
def _table_name_of_path(table_path):
180+
return table_path.split("/")[-1]
181+
182+
183+
def _upgrade_mapping_spec(ds_spec):
184+
assert is_dict_like(ds_spec)
185+
new_ds_spec = {}
186+
new_ds_spec["additional_data_tables"] = {}
187+
for table_name, table_data in ds_spec["tables"].items():
188+
table_df, table_key = table_data
189+
if not is_list_like(table_key):
190+
table_key = [table_key]
191+
if table_name == ds_spec["main_table"]:
192+
new_ds_spec["main_table"] = (table_df, table_key)
193+
else:
194+
table_path = [table_name]
195+
is_entity = False
196+
197+
# Cycle 4 times on the relations to get all transitive relation, like:
198+
# - current table name N
199+
# - main table name N1
200+
# - and relations: (N1, N2), (N2, N3), (N3, N)
201+
# the data-path must be N2/N3/N
202+
# Note: this is a heuristic that should be replaced with a graph
203+
# traversal procedure
204+
# If no "relations" key exists, then one has a star schema and
205+
# the data-paths are the names of the secondary tables themselves
206+
# (with respect to the main table)
207+
if "relations" in ds_spec:
208+
for relation in list(ds_spec["relations"]) * 4:
209+
left, right = relation[:2]
210+
if len(relation) == 3 and right == table_name:
211+
is_entity = relation[2]
212+
if (
213+
left != ds_spec["main_table"]
214+
and left not in table_path
215+
and right in table_path
216+
):
217+
table_path.insert(0, left)
218+
table_path = "/".join(table_path)
219+
if is_entity:
220+
table_data = (table_df, table_key, is_entity)
221+
else:
222+
table_data = (table_df, table_key)
223+
new_ds_spec["additional_data_tables"][table_path] = table_data
224+
return new_ds_spec
225+
226+
174227
def get_khiops_type(numpy_type):
175228
"""Translates a numpy dtype to a Khiops dictionary type
176229
@@ -418,14 +471,26 @@ def _check_input_sequence(self, X, key=None):
418471
# Check the key for the main_table (it is the same for the others)
419472
_check_table_key("main_table", key)
420473

421-
def _table_name_of_path(self, table_path):
422-
# TODO: Add >= 128-character truncation and indexing scheme
423-
return table_path.split("/")[-1]
424-
425474
def _init_tables_from_mapping(self, X):
426475
"""Initializes the table spec from a dict-like 'X'"""
427476
assert is_dict_like(X), "'X' must be dict-like"
428477

478+
# Detect if deprecated mapping specification syntax is used;
479+
# if so, issue deprecation warning and transform it to the new syntax
480+
if "tables" in X.keys() and isinstance(X.get("main_table"), str):
481+
warnings.warn(
482+
deprecation_message(
483+
"This multi-table dataset specification format",
484+
"11.0.1",
485+
replacement=(
486+
"the new data-path-based format, as documented in "
487+
":doc:`multi_table_primer`."
488+
),
489+
quote=False,
490+
)
491+
)
492+
X = _upgrade_mapping_spec(X)
493+
429494
# Check the input mapping
430495
check_dataset_spec(X)
431496

@@ -444,7 +509,7 @@ def _init_tables_from_mapping(self, X):
444509
if "additional_data_tables" in X:
445510
for table_path, table_spec in X["additional_data_tables"].items():
446511
table_source, table_key = table_spec[:2]
447-
table_name = self._table_name_of_path(table_path)
512+
table_name = _table_name_of_path(table_path)
448513
table = PandasTable(
449514
table_name,
450515
table_source,
@@ -461,7 +526,7 @@ def _init_tables_from_mapping(self, X):
461526
parent_table_name = self.main_table.name
462527
else:
463528
table_path_fragments = table_path.split("/")
464-
parent_table_name = self._table_name_of_path(
529+
parent_table_name = _table_name_of_path(
465530
"/".join(table_path_fragments[:-1])
466531
)
467532
self.relations.append(

tests/test_dataset_class.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import shutil
1010
import unittest
11+
import warnings
1112

1213
import numpy as np
1314
import pandas as pd
@@ -16,7 +17,7 @@
1617
from pandas.testing import assert_frame_equal
1718
from sklearn import datasets
1819

19-
from khiops.sklearn.dataset import Dataset
20+
from khiops.sklearn.dataset import Dataset, _upgrade_mapping_spec
2021

2122

2223
class DatasetInputOutputConsistencyTests(unittest.TestCase):
@@ -353,6 +354,102 @@ def get_ref_var_types(self, multitable, schema=None):
353354

354355
return ref_var_types
355356

357+
def test_dataset_of_deprecated_mt_mapping(self):
358+
"""Test deprecated multi-table specification handling"""
359+
(
360+
ref_main_table,
361+
ref_secondary_table_1,
362+
ref_secondary_table_2,
363+
ref_tertiary_table,
364+
ref_quaternary_table,
365+
) = self.create_multitable_snowflake_dataframes()
366+
367+
features_ref_main_table = ref_main_table.drop("class", axis=1)
368+
expected_ds_spec = {
369+
"main_table": (features_ref_main_table, ["User_ID"]),
370+
"additional_data_tables": {
371+
"B": (ref_secondary_table_1, ["User_ID", "VAR_1"], False),
372+
"B/D": (ref_tertiary_table, ["User_ID", "VAR_1", "VAR_2"], False),
373+
"B/D/E": (
374+
ref_quaternary_table,
375+
["User_ID", "VAR_1", "VAR_2", "VAR_3"],
376+
),
377+
"C": (ref_secondary_table_2, ["User_ID"], True),
378+
},
379+
}
380+
deprecated_ds_spec = {
381+
"main_table": "A",
382+
"tables": {
383+
"A": (features_ref_main_table, "User_ID"),
384+
"B": (ref_secondary_table_1, ["User_ID", "VAR_1"]),
385+
"C": (ref_secondary_table_2, "User_ID"),
386+
"D": (ref_tertiary_table, ["User_ID", "VAR_1", "VAR_2"]),
387+
"E": (
388+
ref_quaternary_table,
389+
["User_ID", "VAR_1", "VAR_2", "VAR_3"],
390+
),
391+
},
392+
"relations": {
393+
("A", "B", False),
394+
("B", "D", False),
395+
("D", "E"),
396+
("A", "C", True),
397+
},
398+
}
399+
400+
label = ref_main_table["class"]
401+
402+
# Test that deprecation warning is issued when creating a dataset
403+
# according to the deprecated spec
404+
with warnings.catch_warnings(record=True) as warning_list:
405+
_ = Dataset(deprecated_ds_spec, label)
406+
self.assertTrue(len(warning_list) > 0)
407+
deprecation_warning_found = False
408+
for warning in warning_list:
409+
warning_message = warning.message
410+
if (
411+
issubclass(warning.category, UserWarning)
412+
and len(warning_message.args) == 1
413+
and "multi-table dataset specification format"
414+
in warning_message.args[0]
415+
and "deprecated" in warning_message.args[0]
416+
):
417+
deprecation_warning_found = True
418+
break
419+
self.assertTrue(deprecation_warning_found)
420+
421+
# Test that a deprecated dataset spec is upgraded to the new format
422+
ds_spec = _upgrade_mapping_spec(deprecated_ds_spec)
423+
self.assertEqual(ds_spec.keys(), expected_ds_spec.keys())
424+
main_table = ds_spec["main_table"]
425+
expected_main_table = expected_ds_spec["main_table"]
426+
427+
# Test that main table keys are identical
428+
self.assertEqual(main_table[1], expected_main_table[1])
429+
430+
# Test that main table data frame are equal
431+
assert_frame_equal(main_table[0], expected_main_table[0])
432+
433+
# Test that additional data tables keys are identical
434+
additional_data_tables = ds_spec["additional_data_tables"]
435+
expected_additional_data_tables = expected_ds_spec["additional_data_tables"]
436+
self.assertEqual(
437+
additional_data_tables.keys(), expected_additional_data_tables.keys()
438+
)
439+
440+
for table_path, expected_table_data in expected_additional_data_tables.items():
441+
table_data = additional_data_tables[table_path]
442+
443+
# Test that secondary table keys are identical
444+
self.assertEqual(table_data[1], expected_table_data[1])
445+
446+
# Test that the secondary table data frames are identical
447+
assert_frame_equal(table_data[0], expected_table_data[0])
448+
449+
# Test that the secondary table entity statuses are identical if True
450+
if len(expected_table_data) > 2 and expected_table_data[2] is True:
451+
self.assertEqual(table_data[2], expected_table_data[2])
452+
356453
def test_dataset_is_correctly_built(self):
357454
"""Test that the dataset structure is consistent with the input spec"""
358455
ds_spec, label = self.create_fixture_ds_spec(

0 commit comments

Comments
 (0)