Skip to content

Commit 9b2a022

Browse files
committed
Drop Dataset.relations and make Dataset data-path-aware
1 parent c8cf654 commit 9b2a022

4 files changed

Lines changed: 79 additions & 70 deletions

File tree

khiops/sklearn/dataset.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _check_multitable_spec(ds_spec):
176176
)
177177

178178

179-
def _table_name_of_path(table_path):
179+
def table_name_of_path(table_path):
180180
return table_path.split("/")[-1]
181181

182182

@@ -377,7 +377,6 @@ def __init__(self, X, y=None, categorical_target=True):
377377
# Initialize members
378378
self.main_table = None
379379
self.additional_data_tables = None
380-
self.relations = None
381380
self.categorical_target = categorical_target
382381
self.target_column = None
383382
self.target_column_id = None
@@ -427,7 +426,8 @@ def __init__(self, X, y=None, categorical_target=True):
427426
# Index the tables by name
428427
self._tables_by_name = {
429428
table.name: table
430-
for table in [self.main_table] + self.additional_data_tables
429+
for table in [self.main_table]
430+
+ [table for _, table, _ in self.additional_data_tables]
431431
}
432432

433433
# Post-conditions
@@ -503,32 +503,21 @@ def _init_tables_from_mapping(self, X):
503503
key=main_table_key,
504504
)
505505
self.additional_data_tables = []
506-
self.relations = []
507506
if "additional_data_tables" in X:
508507
for table_path, table_spec in X["additional_data_tables"].items():
509508
table_source, table_key = table_spec[:2]
510-
table_name = _table_name_of_path(table_path)
509+
table_name = table_name_of_path(table_path)
511510
table = PandasTable(
512511
table_name,
513512
table_source,
514-
data_path=table_path,
515513
key=table_key,
516514
)
517-
self.additional_data_tables.append(table)
518515
is_one_to_one_relation = False
519516
if len(table_spec) == 3 and table_spec[2] is True:
520517
is_one_to_one_relation = True
521518

522-
# Set relation parent: if no "/" in path, main_table is the parent
523-
if not "/" in table_path:
524-
parent_table_name = self.main_table.name
525-
else:
526-
table_path_fragments = table_path.split("/")
527-
parent_table_name = _table_name_of_path(
528-
"/".join(table_path_fragments[:-1])
529-
)
530-
self.relations.append(
531-
(parent_table_name, table_name, is_one_to_one_relation)
519+
self.additional_data_tables.append(
520+
(table_path, table, is_one_to_one_relation)
532521
)
533522
# Initialize a sparse dataset (monotable)
534523
elif isinstance(main_table_source, sp.spmatrix):
@@ -538,7 +527,6 @@ def _init_tables_from_mapping(self, X):
538527
key=main_table_key,
539528
)
540529
self.additional_data_tables = []
541-
self.relations = []
542530
# Initialize a numpyarray dataset (monotable)
543531
elif hasattr(main_table_source, "__array__"):
544532
self.main_table = NumpyTable(
@@ -551,7 +539,6 @@ def _init_tables_from_mapping(self, X):
551539
"with pandas dataframe source tables"
552540
)
553541
self.additional_data_tables = []
554-
self.relations = []
555542
else:
556543
raise TypeError(
557544
type_error_message(
@@ -670,11 +657,12 @@ def to_spec(self):
670657
ds_spec = {}
671658
ds_spec["main_table"] = (self.main_table.data_source, self.main_table.key)
672659
ds_spec["additional_data_tables"] = {}
673-
for table in self.additional_data_tables:
674-
assert table.data_path is not None
675-
ds_spec["additional_data_tables"][table.data_path] = (
660+
for table_path, table, is_one_to_one_relation in self.additional_data_tables:
661+
assert table_path is not None
662+
ds_spec["additional_data_tables"][table_path] = (
676663
table.data_source,
677664
table.key,
665+
is_one_to_one_relation,
678666
)
679667

680668
return ds_spec
@@ -738,31 +726,32 @@ def create_khiops_dictionary_domain(self):
738726
# Note: In general 'name' and 'object_type' fields of Variable can be different
739727
if self.additional_data_tables:
740728
main_dictionary.root = True
741-
table_names = [table.name for table in self.additional_data_tables]
742-
tables_to_visit = [self.main_table.name]
743-
while tables_to_visit:
744-
current_table = tables_to_visit.pop(0)
745-
for relation in self.relations:
746-
parent_table, child_table, is_one_to_one_relation = relation
747-
if parent_table == current_table:
748-
tables_to_visit.append(child_table)
749-
parent_table_name = parent_table
750-
index_table = table_names.index(child_table)
751-
table = self.additional_data_tables[index_table]
752-
parent_table_dictionary = dictionary_domain.get_dictionary(
753-
parent_table_name
754-
)
755-
dictionary = table.create_khiops_dictionary()
756-
dictionary_domain.add_dictionary(dictionary)
757-
table_variable = kh.Variable()
758-
if is_one_to_one_relation:
759-
table_variable.type = "Entity"
760-
else:
761-
table_variable.type = "Table"
762-
table_variable.name = table.name
763-
table_variable.object_type = table.name
764-
parent_table_dictionary.add_variable(table_variable)
729+
for (
730+
table_path,
731+
table,
732+
is_one_to_one_relation,
733+
) in self.additional_data_tables:
734+
if not "/" in table_path:
735+
parent_table_name = self.main_table.name
736+
else:
737+
table_path_fragments = table_path.split("/")
738+
parent_table_name = table_name_of_path(
739+
"/".join(table_path_fragments[:-1])
740+
)
741+
parent_table_dictionary = dictionary_domain.get_dictionary(
742+
parent_table_name
743+
)
765744

745+
dictionary = table.create_khiops_dictionary()
746+
dictionary_domain.add_dictionary(dictionary)
747+
table_variable = kh.Variable()
748+
if is_one_to_one_relation:
749+
table_variable.type = "Entity"
750+
else:
751+
table_variable.type = "Table"
752+
table_variable.name = table.name
753+
table_variable.object_type = table.name
754+
parent_table_dictionary.add_variable(table_variable)
766755
return dictionary_domain
767756

768757
def create_table_files_for_khiops(self, output_dir, sort=True):
@@ -801,9 +790,9 @@ def create_table_files_for_khiops(self, output_dir, sort=True):
801790

802791
# Create a copy of each secondary table
803792
secondary_table_paths = {}
804-
for table in self.additional_data_tables:
805-
assert table.data_path is not None
806-
secondary_table_paths[table.data_path] = table.create_table_file_for_khiops(
793+
for table_path, table, _ in self.additional_data_tables:
794+
assert table_path is not None
795+
secondary_table_paths[table_path] = table.create_table_file_for_khiops(
807796
output_dir, sort=sort
808797
)
809798

@@ -908,13 +897,11 @@ class PandasTable(DatasetTable):
908897
Name for the table.
909898
dataframe : `pandas.DataFrame`
910899
The data frame to be encapsulated. It must be non-empty.
911-
data_path : str, optional
912-
Data path of the table. Unset for main tables.
913900
key : list of str, optional
914901
The names of the columns composing the key.
915902
"""
916903

917-
def __init__(self, name, dataframe, data_path=None, key=None):
904+
def __init__(self, name, dataframe, key=None):
918905
# Call the parent method
919906
super().__init__(name=name, key=key)
920907

@@ -927,7 +914,6 @@ def __init__(self, name, dataframe, data_path=None, key=None):
927914
# Initialize the attributes
928915
self.data_source = dataframe
929916
self.n_samples = len(self.data_source)
930-
self.data_path = data_path
931917

932918
# Initialize feature columns and verify their types
933919
self.column_ids = self.data_source.columns.values

khiops/sklearn/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,7 @@ def _transform_check_dataset(self, ds):
15091509

15101510
# Multi-table model: Check name and dictionary coherence of secondary tables
15111511
dataset_secondary_tables_by_name = {
1512-
table.name: table for table in ds.additional_data_tables
1512+
table.name: table for _, table, _ in ds.additional_data_tables
15131513
}
15141514
for dictionary in self.model_.dictionaries:
15151515
assert dictionary.name.startswith(self._khiops_model_prefix), (

khiops/sklearn/helpers.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.model_selection import train_test_split
1212

1313
from khiops.core.internals.common import is_dict_like, type_error_message
14-
from khiops.sklearn.dataset import Dataset
14+
from khiops.sklearn.dataset import Dataset, table_name_of_path
1515

1616
# Note: We build the splits with lists and itertools.chain avoid pylint warning about
1717
# unbalanced-tuple-unpacking. See issue https://github.com/pylint-dev/pylint/issues/5671
@@ -122,15 +122,31 @@ def _train_test_split_in_memory_dataset(ds, y, test_size, sklearn_split_params=N
122122

123123
# Split the secondary tables tables
124124
# Note: The tables are traversed in BFS
125-
todo_relations = [
126-
relation for relation in ds.relations if relation[0] == ds.main_table.name
125+
todo_tables = [
126+
(table_path, table)
127+
for table_path, table, _ in ds.additional_data_tables
128+
if "/" not in table_path
127129
]
128-
while todo_relations:
129-
current_parent_table_name, current_child_table_name, _ = todo_relations.pop(0)
130-
for relation in ds.relations:
131-
parent_table_name, _, _ = relation
130+
while todo_tables:
131+
current_table_path, current_table = todo_tables.pop(0)
132+
if "/" not in current_table_path:
133+
current_parent_table_name = ds.main_table.name
134+
else:
135+
table_path_fragments = current_table_path.split("/")
136+
current_parent_table_name = table_name_of_path(
137+
"/".join(table_path_fragments[:-1])
138+
)
139+
current_child_table_name = current_table.name
140+
for secondary_table_path, secondary_table, _ in ds.additional_data_tables:
141+
if "/" not in secondary_table_path:
142+
parent_table_name = ds.main_table.name
143+
else:
144+
table_path_fragments = secondary_table_path.split("/")
145+
parent_table_name = table_name_of_path(
146+
"/".join(table_path_fragments[:-1])
147+
)
132148
if parent_table_name == current_child_table_name:
133-
todo_relations.append(relation)
149+
todo_tables.append((secondary_table_path, secondary_table))
134150

135151
for new_ds in (train_ds, test_ds):
136152
origin_child_table = ds.get_table(current_child_table_name)

tests/test_dataset_class.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -460,19 +460,22 @@ def test_dataset_is_correctly_built(self):
460460
self.assertEqual(dataset.main_table.name, "main_table")
461461
self.assertEqual(len(dataset.additional_data_tables), 4)
462462
dataset_secondary_table_names = {
463-
secondary_table.name for secondary_table in dataset.additional_data_tables
463+
secondary_table.name
464+
for _, secondary_table, _ in dataset.additional_data_tables
464465
}
465466
self.assertEqual(dataset_secondary_table_names, {"B", "C", "D", "E"})
466-
self.assertEqual(len(dataset.relations), 4)
467467

468468
table_specs = ds_spec["additional_data_tables"].items()
469-
for relation, (table_path, table_spec) in zip(dataset.relations, table_specs):
469+
for (ds_table_path, _, ds_is_one_to_one), (
470+
table_path,
471+
table_spec,
472+
) in zip(dataset.additional_data_tables, table_specs):
470473
# The relation holds the table name, not the table path
471-
self.assertEqual(relation[1], table_path.split("/")[-1])
474+
self.assertEqual(ds_table_path, table_path)
472475
if len(table_spec) == 3:
473-
self.assertEqual(relation[2], table_spec[2])
476+
self.assertEqual(ds_is_one_to_one, table_spec[2])
474477
else:
475-
self.assertFalse(relation[2])
478+
self.assertFalse(ds_is_one_to_one)
476479

477480
def test_out_file_from_dataframe_monotable(self):
478481
"""Test consistency of the created data file with the input dataframe
@@ -746,7 +749,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
746749

747750
# Check that the domain has the same table names as the reference
748751
ref_table_names = {
749-
table.name for table in [ds.main_table] + ds.additional_data_tables
752+
table.name
753+
for table in [ds.main_table]
754+
+ [table for _, table, _ in ds.additional_data_tables]
750755
}
751756
out_table_names = {dictionary.name for dictionary in out_domain.dictionaries}
752757
self.assertEqual(ref_table_names, out_table_names)
@@ -759,7 +764,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
759764
# Check that:
760765
# - the table keys are the same as the dataset
761766
# - the domain has the same variable names as the reference
762-
for table in [ds.main_table] + ds.additional_data_tables:
767+
for table in [ds.main_table] + [
768+
table for _, table, _ in ds.additional_data_tables
769+
]:
763770
with self.subTest(table=table.name):
764771
self.assertEqual(table.key, out_domain.get_dictionary(table.name).key)
765772
out_dictionary_var_types = {

0 commit comments

Comments
 (0)