Skip to content

Commit 3fc305d

Browse files
committed
Add deprecation path for the hierarchical schema specification in sklearn
1 parent 6684faa commit 3fc305d

File tree

2 files changed

+169
-8
lines changed

2 files changed

+169
-8
lines changed

khiops/sklearn/dataset.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import warnings
1111
from abc import ABC, abstractmethod
1212
from collections.abc import Iterable, Mapping, Sequence
13+
from itertools import cycle, islice
1314

1415
import numpy as np
1516
import pandas as pd
@@ -21,7 +22,12 @@
2122
import khiops.core as kh
2223
import khiops.core.internals.filesystems as fs
2324
from khiops.core.dictionary import VariableBlock
24-
from khiops.core.internals.common import is_dict_like, is_list_like, type_error_message
25+
from khiops.core.internals.common import (
26+
deprecation_message,
27+
is_dict_like,
28+
is_list_like,
29+
type_error_message,
30+
)
2531

2632
# Disable PEP8 variable names because of scikit-learn X,y conventions
2733
# To capture invalid-names other than X,y run:
@@ -171,6 +177,52 @@ def _check_multitable_spec(ds_spec):
171177
)
172178

173179

180+
def _table_name_of_path(table_path):
181+
return table_path.split("/")[-1]
182+
183+
184+
def _upgrade_mapping_spec(X):
    """Translate a deprecated "tables"-based dataset spec to the data-path format.

    The deprecated format lists every table under ``X["tables"]`` and the
    parent/child links under ``X["relations"]``; the new format keys each
    secondary table by its "/"-separated data path relative to the main
    table. A truthy third element in a relation tuple (the "entity" flag)
    is propagated as a third element of the resulting table spec.

    If no "relations" key exists the schema is a star and the data paths
    are simply the names of the secondary tables themselves.

    Parameters
    ----------
    X : dict-like
        Deprecated dataset spec containing at least the "main_table" and
        "tables" keys, and optionally "relations".

    Returns
    -------
    dict
        Spec with the "main_table" and "additional_data_tables" keys.
    """
    assert is_dict_like(X)
    new_X = {"additional_data_tables": {}}
    relations = list(X["relations"]) if "relations" in X else []
    for table_name, table_data in X["tables"].items():
        table_df, table_key = table_data
        # Normalize a scalar key to a single-column key list
        if not is_list_like(table_key):
            table_key = [table_key]
        if table_name == X["main_table"]:
            new_X["main_table"] = (table_df, table_key)
            continue

        # Build the data path by walking the relations bottom-up from the
        # current table towards (but excluding) the main table. Iterate to a
        # fixpoint so arbitrarily deep chains are resolved regardless of the
        # order in which the relations are listed. (A fixed two-pass scan —
        # cycling twice over the relations — can miss parents of tables
        # nested more than a few levels deep when the relations are listed
        # child-first.)
        table_path = [table_name]
        is_entity = False
        changed = True
        while changed:
            changed = False
            for relation in relations:
                left, right = relation[:2]
                # Pick up the entity flag of the relation ending at this table
                if len(relation) == 3 and right == table_name:
                    is_entity = relation[2]
                if (
                    left != X["main_table"]
                    and left not in table_path
                    and right in table_path
                ):
                    table_path.insert(0, left)
                    changed = True
        data_path = "/".join(table_path)
        if is_entity:
            table_data = (table_df, table_key, is_entity)
        else:
            table_data = (table_df, table_key)
        new_X["additional_data_tables"][data_path] = table_data
    return new_X
224+
225+
174226
def get_khiops_type(numpy_type):
175227
"""Translates a numpy dtype to a Khiops dictionary type
176228
@@ -418,14 +470,26 @@ def _check_input_sequence(self, X, key=None):
418470
# Check the key for the main_table (it is the same for the others)
419471
_check_table_key("main_table", key)
420472

421-
def _table_name_of_path(self, table_path):
422-
# TODO: Add >= 128-character truncation and indexing scheme
423-
return table_path.split("/")[-1]
424-
425473
def _init_tables_from_mapping(self, X):
426474
"""Initializes the table spec from a dict-like 'X'"""
427475
assert is_dict_like(X), "'X' must be dict-like"
428476

477+
# Detect if deprecated mapping specification syntax is used;
478+
# if so, issue deprecation warning and transform it to the new syntax
479+
if "tables" in X.keys():
480+
warnings.warn(
481+
deprecation_message(
482+
"This multi-table dataset specification format",
483+
"11.0.1",
484+
replacement=(
485+
"the new data-path-based format, as documented in "
486+
":doc:`multi_table_primer`."
487+
),
488+
quote=False,
489+
)
490+
)
491+
X = _upgrade_mapping_spec(X)
492+
429493
# Check the input mapping
430494
check_dataset_spec(X)
431495

@@ -444,7 +508,7 @@ def _init_tables_from_mapping(self, X):
444508
if "additional_data_tables" in X:
445509
for table_path, table_spec in X["additional_data_tables"].items():
446510
table_source, table_key = table_spec[:2]
447-
table_name = self._table_name_of_path(table_path)
511+
table_name = _table_name_of_path(table_path)
448512
table = PandasTable(
449513
table_name,
450514
table_source,
@@ -461,7 +525,7 @@ def _init_tables_from_mapping(self, X):
461525
parent_table_name = self.main_table.name
462526
else:
463527
table_path_fragments = table_path.split("/")
464-
parent_table_name = self._table_name_of_path(
528+
parent_table_name = _table_name_of_path(
465529
"/".join(table_path_fragments[:-1])
466530
)
467531
self.relations.append(

tests/test_dataset_class.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import shutil
1010
import unittest
11+
import warnings
1112

1213
import numpy as np
1314
import pandas as pd
@@ -16,7 +17,7 @@
1617
from pandas.testing import assert_frame_equal
1718
from sklearn import datasets
1819

19-
from khiops.sklearn.dataset import Dataset
20+
from khiops.sklearn.dataset import Dataset, _upgrade_mapping_spec
2021

2122

2223
class DatasetInputOutputConsistencyTests(unittest.TestCase):
@@ -353,6 +354,102 @@ def get_ref_var_types(self, multitable, schema=None):
353354

354355
return ref_var_types
355356

357+
def test_dataset_of_deprecated_mt_mapping(self):
358+
"""Test deprecated multi-table specification handling"""
359+
(
360+
ref_main_table,
361+
ref_secondary_table_1,
362+
ref_secondary_table_2,
363+
ref_tertiary_table,
364+
ref_quaternary_table,
365+
) = self.create_multitable_snowflake_dataframes()
366+
367+
features_ref_main_table = ref_main_table.drop("class", axis=1)
368+
expected_ds_spec = {
369+
"main_table": (features_ref_main_table, ["User_ID"]),
370+
"additional_data_tables": {
371+
"B": (ref_secondary_table_1, ["User_ID", "VAR_1"], False),
372+
"B/D": (ref_tertiary_table, ["User_ID", "VAR_1", "VAR_2"], False),
373+
"B/D/E": (
374+
ref_quaternary_table,
375+
["User_ID", "VAR_1", "VAR_2", "VAR_3"],
376+
),
377+
"C": (ref_secondary_table_2, ["User_ID"], True),
378+
},
379+
}
380+
deprecated_ds_spec = {
381+
"main_table": "A",
382+
"tables": {
383+
"A": (features_ref_main_table, "User_ID"),
384+
"B": (ref_secondary_table_1, ["User_ID", "VAR_1"]),
385+
"C": (ref_secondary_table_2, "User_ID"),
386+
"D": (ref_tertiary_table, ["User_ID", "VAR_1", "VAR_2"]),
387+
"E": (
388+
ref_quaternary_table,
389+
["User_ID", "VAR_1", "VAR_2", "VAR_3"],
390+
),
391+
},
392+
"relations": {
393+
("A", "B", False),
394+
("B", "D", False),
395+
("D", "E"),
396+
("A", "C", True),
397+
},
398+
}
399+
400+
label = ref_main_table["class"]
401+
402+
# Test that deprecation warning is issued when creating a dataset
403+
# according to the deprecated spec
404+
with warnings.catch_warnings(record=True) as warning_list:
405+
_ = Dataset(deprecated_ds_spec, label)
406+
self.assertTrue(len(warning_list) > 0)
407+
deprecation_warning_found = False
408+
for warning in warning_list:
409+
warning_message = warning.message
410+
if (
411+
issubclass(warning.category, UserWarning)
412+
and len(warning_message.args) == 1
413+
and "multi-table dataset specification format"
414+
in warning_message.args[0]
415+
and "deprecated" in warning_message.args[0]
416+
):
417+
deprecation_warning_found = True
418+
break
419+
self.assertTrue(deprecation_warning_found)
420+
421+
# Test that a deprecated dataset spec is upgraded to the new format
422+
ds_spec = _upgrade_mapping_spec(deprecated_ds_spec)
423+
self.assertEqual(ds_spec.keys(), expected_ds_spec.keys())
424+
main_table = ds_spec["main_table"]
425+
expected_main_table = expected_ds_spec["main_table"]
426+
427+
# Test that main table keys are identical
428+
self.assertEqual(main_table[1], expected_main_table[1])
429+
430+
# Test that main table data frame are equal
431+
assert_frame_equal(main_table[0], expected_main_table[0])
432+
433+
# Test that additional data tables keys are identical
434+
additional_data_tables = ds_spec["additional_data_tables"]
435+
expected_additional_data_tables = expected_ds_spec["additional_data_tables"]
436+
self.assertEqual(
437+
additional_data_tables.keys(), expected_additional_data_tables.keys()
438+
)
439+
440+
for table_path, expected_table_data in expected_additional_data_tables.items():
441+
table_data = additional_data_tables[table_path]
442+
443+
# Test that secondary table keys are identical
444+
self.assertEqual(table_data[1], expected_table_data[1])
445+
446+
# Test that the secondary table data frames are identical
447+
assert_frame_equal(table_data[0], expected_table_data[0])
448+
449+
# Test that the secondary table entity statuses are identical if True
450+
if len(expected_table_data) > 2 and expected_table_data[2] is True:
451+
self.assertEqual(table_data[2], expected_table_data[2])
452+
356453
def test_dataset_is_correctly_built(self):
357454
"""Test that the dataset structure is consistent with the input spec"""
358455
ds_spec, label = self.create_fixture_ds_spec(

0 commit comments

Comments
 (0)