Skip to content

Commit 05fd0cf

Browse files
committed
Drop Dataset.relations and make Dataset data-path-aware
1 parent 586cb28 commit 05fd0cf

File tree

4 files changed

+79
-70
lines changed

4 files changed

+79
-70
lines changed

khiops/sklearn/dataset.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def _check_multitable_spec(ds_spec):
177177
)
178178

179179

180-
def _table_name_of_path(table_path):
180+
def table_name_of_path(table_path):
181181
return table_path.split("/")[-1]
182182

183183

@@ -378,7 +378,6 @@ def __init__(self, X, y=None, categorical_target=True):
378378
# Initialize members
379379
self.main_table = None
380380
self.additional_data_tables = None
381-
self.relations = None
382381
self.categorical_target = categorical_target
383382
self.target_column = None
384383
self.target_column_id = None
@@ -428,7 +427,8 @@ def __init__(self, X, y=None, categorical_target=True):
428427
# Index the tables by name
429428
self._tables_by_name = {
430429
table.name: table
431-
for table in [self.main_table] + self.additional_data_tables
430+
for table in [self.main_table]
431+
+ [table for _, table, _ in self.additional_data_tables]
432432
}
433433

434434
# Post-conditions
@@ -504,32 +504,21 @@ def _init_tables_from_mapping(self, X):
504504
key=main_table_key,
505505
)
506506
self.additional_data_tables = []
507-
self.relations = []
508507
if "additional_data_tables" in X:
509508
for table_path, table_spec in X["additional_data_tables"].items():
510509
table_source, table_key = table_spec[:2]
511-
table_name = _table_name_of_path(table_path)
510+
table_name = table_name_of_path(table_path)
512511
table = PandasTable(
513512
table_name,
514513
table_source,
515-
data_path=table_path,
516514
key=table_key,
517515
)
518-
self.additional_data_tables.append(table)
519516
is_one_to_one_relation = False
520517
if len(table_spec) == 3 and table_spec[2] is True:
521518
is_one_to_one_relation = True
522519

523-
# Set relation parent: if no "/" in path, main_table is the parent
524-
if not "/" in table_path:
525-
parent_table_name = self.main_table.name
526-
else:
527-
table_path_fragments = table_path.split("/")
528-
parent_table_name = _table_name_of_path(
529-
"/".join(table_path_fragments[:-1])
530-
)
531-
self.relations.append(
532-
(parent_table_name, table_name, is_one_to_one_relation)
520+
self.additional_data_tables.append(
521+
(table_path, table, is_one_to_one_relation)
533522
)
534523
# Initialize a sparse dataset (monotable)
535524
elif isinstance(main_table_source, sp.spmatrix):
@@ -539,7 +528,6 @@ def _init_tables_from_mapping(self, X):
539528
key=main_table_key,
540529
)
541530
self.additional_data_tables = []
542-
self.relations = []
543531
# Initialize a numpyarray dataset (monotable)
544532
elif hasattr(main_table_source, "__array__"):
545533
self.main_table = NumpyTable(
@@ -552,7 +540,6 @@ def _init_tables_from_mapping(self, X):
552540
"with pandas dataframe source tables"
553541
)
554542
self.additional_data_tables = []
555-
self.relations = []
556543
else:
557544
raise TypeError(
558545
type_error_message(
@@ -671,11 +658,12 @@ def to_spec(self):
671658
ds_spec = {}
672659
ds_spec["main_table"] = (self.main_table.data_source, self.main_table.key)
673660
ds_spec["additional_data_tables"] = {}
674-
for table in self.additional_data_tables:
675-
assert table.data_path is not None
676-
ds_spec["additional_data_tables"][table.data_path] = (
661+
for table_path, table, is_one_to_one_relation in self.additional_data_tables:
662+
assert table_path is not None
663+
ds_spec["additional_data_tables"][table_path] = (
677664
table.data_source,
678665
table.key,
666+
is_one_to_one_relation,
679667
)
680668

681669
return ds_spec
@@ -739,31 +727,32 @@ def create_khiops_dictionary_domain(self):
739727
# Note: In general 'name' and 'object_type' fields of Variable can be different
740728
if self.additional_data_tables:
741729
main_dictionary.root = True
742-
table_names = [table.name for table in self.additional_data_tables]
743-
tables_to_visit = [self.main_table.name]
744-
while tables_to_visit:
745-
current_table = tables_to_visit.pop(0)
746-
for relation in self.relations:
747-
parent_table, child_table, is_one_to_one_relation = relation
748-
if parent_table == current_table:
749-
tables_to_visit.append(child_table)
750-
parent_table_name = parent_table
751-
index_table = table_names.index(child_table)
752-
table = self.additional_data_tables[index_table]
753-
parent_table_dictionary = dictionary_domain.get_dictionary(
754-
parent_table_name
755-
)
756-
dictionary = table.create_khiops_dictionary()
757-
dictionary_domain.add_dictionary(dictionary)
758-
table_variable = kh.Variable()
759-
if is_one_to_one_relation:
760-
table_variable.type = "Entity"
761-
else:
762-
table_variable.type = "Table"
763-
table_variable.name = table.name
764-
table_variable.object_type = table.name
765-
parent_table_dictionary.add_variable(table_variable)
730+
for (
731+
table_path,
732+
table,
733+
is_one_to_one_relation,
734+
) in self.additional_data_tables:
735+
if not "/" in table_path:
736+
parent_table_name = self.main_table.name
737+
else:
738+
table_path_fragments = table_path.split("/")
739+
parent_table_name = table_name_of_path(
740+
"/".join(table_path_fragments[:-1])
741+
)
742+
parent_table_dictionary = dictionary_domain.get_dictionary(
743+
parent_table_name
744+
)
766745

746+
dictionary = table.create_khiops_dictionary()
747+
dictionary_domain.add_dictionary(dictionary)
748+
table_variable = kh.Variable()
749+
if is_one_to_one_relation:
750+
table_variable.type = "Entity"
751+
else:
752+
table_variable.type = "Table"
753+
table_variable.name = table.name
754+
table_variable.object_type = table.name
755+
parent_table_dictionary.add_variable(table_variable)
767756
return dictionary_domain
768757

769758
def create_table_files_for_khiops(self, output_dir, sort=True):
@@ -802,9 +791,9 @@ def create_table_files_for_khiops(self, output_dir, sort=True):
802791

803792
# Create a copy of each secondary table
804793
secondary_table_paths = {}
805-
for table in self.additional_data_tables:
806-
assert table.data_path is not None
807-
secondary_table_paths[table.data_path] = table.create_table_file_for_khiops(
794+
for table_path, table, _ in self.additional_data_tables:
795+
assert table_path is not None
796+
secondary_table_paths[table_path] = table.create_table_file_for_khiops(
808797
output_dir, sort=sort
809798
)
810799

@@ -909,13 +898,11 @@ class PandasTable(DatasetTable):
909898
Name for the table.
910899
dataframe : `pandas.DataFrame`
911900
The data frame to be encapsulated. It must be non-empty.
912-
data_path : str, optional
913-
Data path of the table. Unset for main tables.
914901
key : list of str, optional
915902
The names of the columns composing the key.
916903
"""
917904

918-
def __init__(self, name, dataframe, data_path=None, key=None):
905+
def __init__(self, name, dataframe, key=None):
919906
# Call the parent method
920907
super().__init__(name=name, key=key)
921908

@@ -928,7 +915,6 @@ def __init__(self, name, dataframe, data_path=None, key=None):
928915
# Initialize the attributes
929916
self.data_source = dataframe
930917
self.n_samples = len(self.data_source)
931-
self.data_path = data_path
932918

933919
# Initialize feature columns and verify their types
934920
self.column_ids = self.data_source.columns.values

khiops/sklearn/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,7 @@ def _transform_check_dataset(self, ds):
15091509

15101510
# Multi-table model: Check name and dictionary coherence of secondary tables
15111511
dataset_secondary_tables_by_name = {
1512-
table.name: table for table in ds.additional_data_tables
1512+
table.name: table for _, table, _ in ds.additional_data_tables
15131513
}
15141514
for dictionary in self.model_.dictionaries:
15151515
assert dictionary.name.startswith(self._khiops_model_prefix), (

khiops/sklearn/helpers.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.model_selection import train_test_split
1212

1313
from khiops.core.internals.common import is_dict_like, type_error_message
14-
from khiops.sklearn.dataset import Dataset
14+
from khiops.sklearn.dataset import Dataset, table_name_of_path
1515

1616
# Note: We build the splits with lists and itertools.chain to avoid pylint warning about
1717
# unbalanced-tuple-unpacking. See issue https://github.com/pylint-dev/pylint/issues/5671
@@ -122,15 +122,31 @@ def _train_test_split_in_memory_dataset(ds, y, test_size, sklearn_split_params=N
122122

123123
# Split the secondary tables
124124
# Note: The tables are traversed in BFS
125-
todo_relations = [
126-
relation for relation in ds.relations if relation[0] == ds.main_table.name
125+
todo_tables = [
126+
(table_path, table)
127+
for table_path, table, _ in ds.additional_data_tables
128+
if "/" not in table_path
127129
]
128-
while todo_relations:
129-
current_parent_table_name, current_child_table_name, _ = todo_relations.pop(0)
130-
for relation in ds.relations:
131-
parent_table_name, _, _ = relation
130+
while todo_tables:
131+
current_table_path, current_table = todo_tables.pop(0)
132+
if "/" not in current_table_path:
133+
current_parent_table_name = ds.main_table.name
134+
else:
135+
table_path_fragments = current_table_path.split("/")
136+
current_parent_table_name = table_name_of_path(
137+
"/".join(table_path_fragments[:-1])
138+
)
139+
current_child_table_name = current_table.name
140+
for secondary_table_path, secondary_table, _ in ds.additional_data_tables:
141+
if "/" not in secondary_table_path:
142+
parent_table_name = ds.main_table.name
143+
else:
144+
table_path_fragments = secondary_table_path.split("/")
145+
parent_table_name = table_name_of_path(
146+
"/".join(table_path_fragments[:-1])
147+
)
132148
if parent_table_name == current_child_table_name:
133-
todo_relations.append(relation)
149+
todo_tables.append((secondary_table_path, secondary_table))
134150

135151
for new_ds in (train_ds, test_ds):
136152
origin_child_table = ds.get_table(current_child_table_name)

tests/test_dataset_class.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -460,19 +460,22 @@ def test_dataset_is_correctly_built(self):
460460
self.assertEqual(dataset.main_table.name, "main_table")
461461
self.assertEqual(len(dataset.additional_data_tables), 4)
462462
dataset_secondary_table_names = {
463-
secondary_table.name for secondary_table in dataset.additional_data_tables
463+
secondary_table.name
464+
for _, secondary_table, _ in dataset.additional_data_tables
464465
}
465466
self.assertEqual(dataset_secondary_table_names, {"B", "C", "D", "E"})
466-
self.assertEqual(len(dataset.relations), 4)
467467

468468
table_specs = ds_spec["additional_data_tables"].items()
469-
for relation, (table_path, table_spec) in zip(dataset.relations, table_specs):
469+
for (ds_table_path, _, ds_is_one_to_one), (
470+
table_path,
471+
table_spec,
472+
) in zip(dataset.additional_data_tables, table_specs):
470473
# Each additional table entry holds the full table path, as given in the spec
471-
self.assertEqual(relation[1], table_path.split("/")[-1])
474+
self.assertEqual(ds_table_path, table_path)
472475
if len(table_spec) == 3:
473-
self.assertEqual(relation[2], table_spec[2])
476+
self.assertEqual(ds_is_one_to_one, table_spec[2])
474477
else:
475-
self.assertFalse(relation[2])
478+
self.assertFalse(ds_is_one_to_one)
476479

477480
def test_out_file_from_dataframe_monotable(self):
478481
"""Test consistency of the created data file with the input dataframe
@@ -746,7 +749,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
746749

747750
# Check that the domain has the same table names as the reference
748751
ref_table_names = {
749-
table.name for table in [ds.main_table] + ds.additional_data_tables
752+
table.name
753+
for table in [ds.main_table]
754+
+ [table for _, table, _ in ds.additional_data_tables]
750755
}
751756
out_table_names = {dictionary.name for dictionary in out_domain.dictionaries}
752757
self.assertEqual(ref_table_names, out_table_names)
@@ -759,7 +764,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
759764
# Check that:
760765
# - the table keys are the same as the dataset
761766
# - the domain has the same variable names as the reference
762-
for table in [ds.main_table] + ds.additional_data_tables:
767+
for table in [ds.main_table] + [
768+
table for _, table, _ in ds.additional_data_tables
769+
]:
763770
with self.subTest(table=table.name):
764771
self.assertEqual(table.key, out_domain.get_dictionary(table.name).key)
765772
out_dictionary_var_types = {

0 commit comments

Comments
 (0)