Skip to content

Commit ddaa5fc

Browse files
committed
Drop Dataset.relations and make Dataset data-path-aware
1 parent c7f1faa commit ddaa5fc

File tree

4 files changed

+79
-70
lines changed

4 files changed

+79
-70
lines changed

khiops/sklearn/dataset.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _check_multitable_spec(ds_spec):
176176
)
177177

178178

179-
def _table_name_of_path(table_path):
179+
def table_name_of_path(table_path):
180180
return table_path.split("/")[-1]
181181

182182

@@ -379,7 +379,6 @@ def __init__(self, X, y=None, categorical_target=True):
379379
# Initialize members
380380
self.main_table = None
381381
self.additional_data_tables = None
382-
self.relations = None
383382
self.categorical_target = categorical_target
384383
self.target_column = None
385384
self.target_column_id = None
@@ -429,7 +428,8 @@ def __init__(self, X, y=None, categorical_target=True):
429428
# Index the tables by name
430429
self._tables_by_name = {
431430
table.name: table
432-
for table in [self.main_table] + self.additional_data_tables
431+
for table in [self.main_table]
432+
+ [table for _, table, _ in self.additional_data_tables]
433433
}
434434

435435
# Post-conditions
@@ -505,32 +505,21 @@ def _init_tables_from_mapping(self, X):
505505
key=main_table_key,
506506
)
507507
self.additional_data_tables = []
508-
self.relations = []
509508
if "additional_data_tables" in X:
510509
for table_path, table_spec in X["additional_data_tables"].items():
511510
table_source, table_key = table_spec[:2]
512-
table_name = _table_name_of_path(table_path)
511+
table_name = table_name_of_path(table_path)
513512
table = PandasTable(
514513
table_name,
515514
table_source,
516-
data_path=table_path,
517515
key=table_key,
518516
)
519-
self.additional_data_tables.append(table)
520517
is_one_to_one_relation = False
521518
if len(table_spec) == 3 and table_spec[2] is True:
522519
is_one_to_one_relation = True
523520

524-
# Set relation parent: if no "/" in path, main_table is the parent
525-
if not "/" in table_path:
526-
parent_table_name = self.main_table.name
527-
else:
528-
table_path_fragments = table_path.split("/")
529-
parent_table_name = _table_name_of_path(
530-
"/".join(table_path_fragments[:-1])
531-
)
532-
self.relations.append(
533-
(parent_table_name, table_name, is_one_to_one_relation)
521+
self.additional_data_tables.append(
522+
(table_path, table, is_one_to_one_relation)
534523
)
535524
# Initialize a sparse dataset (monotable)
536525
elif isinstance(main_table_source, sp.spmatrix):
@@ -540,7 +529,6 @@ def _init_tables_from_mapping(self, X):
540529
key=main_table_key,
541530
)
542531
self.additional_data_tables = []
543-
self.relations = []
544532
# Initialize a numpyarray dataset (monotable)
545533
elif hasattr(main_table_source, "__array__"):
546534
self.main_table = NumpyTable(
@@ -553,7 +541,6 @@ def _init_tables_from_mapping(self, X):
553541
"with pandas dataframe source tables"
554542
)
555543
self.additional_data_tables = []
556-
self.relations = []
557544
else:
558545
raise TypeError(
559546
type_error_message(
@@ -672,11 +659,12 @@ def to_spec(self):
672659
ds_spec = {}
673660
ds_spec["main_table"] = (self.main_table.data_source, self.main_table.key)
674661
ds_spec["additional_data_tables"] = {}
675-
for table in self.additional_data_tables:
676-
assert table.data_path is not None
677-
ds_spec["additional_data_tables"][table.data_path] = (
662+
for table_path, table, is_one_to_one_relation in self.additional_data_tables:
663+
assert table_path is not None
664+
ds_spec["additional_data_tables"][table_path] = (
678665
table.data_source,
679666
table.key,
667+
is_one_to_one_relation,
680668
)
681669

682670
return ds_spec
@@ -740,31 +728,32 @@ def create_khiops_dictionary_domain(self):
740728
# Note: In general 'name' and 'object_type' fields of Variable can be different
741729
if self.additional_data_tables:
742730
main_dictionary.root = True
743-
table_names = [table.name for table in self.additional_data_tables]
744-
tables_to_visit = [self.main_table.name]
745-
while tables_to_visit:
746-
current_table = tables_to_visit.pop(0)
747-
for relation in self.relations:
748-
parent_table, child_table, is_one_to_one_relation = relation
749-
if parent_table == current_table:
750-
tables_to_visit.append(child_table)
751-
parent_table_name = parent_table
752-
index_table = table_names.index(child_table)
753-
table = self.additional_data_tables[index_table]
754-
parent_table_dictionary = dictionary_domain.get_dictionary(
755-
parent_table_name
756-
)
757-
dictionary = table.create_khiops_dictionary()
758-
dictionary_domain.add_dictionary(dictionary)
759-
table_variable = kh.Variable()
760-
if is_one_to_one_relation:
761-
table_variable.type = "Entity"
762-
else:
763-
table_variable.type = "Table"
764-
table_variable.name = table.name
765-
table_variable.object_type = table.name
766-
parent_table_dictionary.add_variable(table_variable)
731+
for (
732+
table_path,
733+
table,
734+
is_one_to_one_relation,
735+
) in self.additional_data_tables:
736+
if not "/" in table_path:
737+
parent_table_name = self.main_table.name
738+
else:
739+
table_path_fragments = table_path.split("/")
740+
parent_table_name = table_name_of_path(
741+
"/".join(table_path_fragments[:-1])
742+
)
743+
parent_table_dictionary = dictionary_domain.get_dictionary(
744+
parent_table_name
745+
)
767746

747+
dictionary = table.create_khiops_dictionary()
748+
dictionary_domain.add_dictionary(dictionary)
749+
table_variable = kh.Variable()
750+
if is_one_to_one_relation:
751+
table_variable.type = "Entity"
752+
else:
753+
table_variable.type = "Table"
754+
table_variable.name = table.name
755+
table_variable.object_type = table.name
756+
parent_table_dictionary.add_variable(table_variable)
768757
return dictionary_domain
769758

770759
def create_table_files_for_khiops(self, output_dir, sort=True):
@@ -803,9 +792,9 @@ def create_table_files_for_khiops(self, output_dir, sort=True):
803792

804793
# Create a copy of each secondary table
805794
secondary_table_paths = {}
806-
for table in self.additional_data_tables:
807-
assert table.data_path is not None
808-
secondary_table_paths[table.data_path] = table.create_table_file_for_khiops(
795+
for table_path, table, _ in self.additional_data_tables:
796+
assert table_path is not None
797+
secondary_table_paths[table_path] = table.create_table_file_for_khiops(
809798
output_dir, sort=sort
810799
)
811800

@@ -910,13 +899,11 @@ class PandasTable(DatasetTable):
910899
Name for the table.
911900
dataframe : `pandas.DataFrame`
912901
The data frame to be encapsulated. It must be non-empty.
913-
data_path : str, optional
914-
Data path of the table. Unset for main tables.
915902
key : list of str, optional
916903
The names of the columns composing the key.
917904
"""
918905

919-
def __init__(self, name, dataframe, data_path=None, key=None):
906+
def __init__(self, name, dataframe, key=None):
920907
# Call the parent method
921908
super().__init__(name=name, key=key)
922909

@@ -929,7 +916,6 @@ def __init__(self, name, dataframe, data_path=None, key=None):
929916
# Initialize the attributes
930917
self.data_source = dataframe
931918
self.n_samples = len(self.data_source)
932-
self.data_path = data_path
933919

934920
# Initialize feature columns and verify their types
935921
self.column_ids = self.data_source.columns.values

khiops/sklearn/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,7 @@ def _transform_check_dataset(self, ds):
15091509

15101510
# Multi-table model: Check name and dictionary coherence of secondary tables
15111511
dataset_secondary_tables_by_name = {
1512-
table.name: table for table in ds.additional_data_tables
1512+
table.name: table for _, table, _ in ds.additional_data_tables
15131513
}
15141514
for dictionary in self.model_.dictionaries:
15151515
assert dictionary.name.startswith(self._khiops_model_prefix), (

khiops/sklearn/helpers.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.model_selection import train_test_split
1212

1313
from khiops.core.internals.common import is_dict_like, type_error_message
14-
from khiops.sklearn.dataset import Dataset
14+
from khiops.sklearn.dataset import Dataset, table_name_of_path
1515

1616
# Note: We build the splits with lists and itertools.chain to avoid pylint warnings about
1717
# unbalanced-tuple-unpacking. See issue https://github.com/pylint-dev/pylint/issues/5671
@@ -122,15 +122,31 @@ def _train_test_split_in_memory_dataset(ds, y, test_size, sklearn_split_params=N
122122

123123
# Split the secondary tables
124124
# Note: The tables are traversed in BFS
125-
todo_relations = [
126-
relation for relation in ds.relations if relation[0] == ds.main_table.name
125+
todo_tables = [
126+
(table_path, table)
127+
for table_path, table, _ in ds.additional_data_tables
128+
if "/" not in table_path
127129
]
128-
while todo_relations:
129-
current_parent_table_name, current_child_table_name, _ = todo_relations.pop(0)
130-
for relation in ds.relations:
131-
parent_table_name, _, _ = relation
130+
while todo_tables:
131+
current_table_path, current_table = todo_tables.pop(0)
132+
if "/" not in current_table_path:
133+
current_parent_table_name = ds.main_table.name
134+
else:
135+
table_path_fragments = current_table_path.split("/")
136+
current_parent_table_name = table_name_of_path(
137+
"/".join(table_path_fragments[:-1])
138+
)
139+
current_child_table_name = current_table.name
140+
for secondary_table_path, secondary_table, _ in ds.additional_data_tables:
141+
if "/" not in secondary_table_path:
142+
parent_table_name = ds.main_table.name
143+
else:
144+
table_path_fragments = secondary_table_path.split("/")
145+
parent_table_name = table_name_of_path(
146+
"/".join(table_path_fragments[:-1])
147+
)
132148
if parent_table_name == current_child_table_name:
133-
todo_relations.append(relation)
149+
todo_tables.append((secondary_table_path, secondary_table))
134150

135151
for new_ds in (train_ds, test_ds):
136152
origin_child_table = ds.get_table(current_child_table_name)

tests/test_dataset_class.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -460,19 +460,22 @@ def test_dataset_is_correctly_built(self):
460460
self.assertEqual(dataset.main_table.name, "main_table")
461461
self.assertEqual(len(dataset.additional_data_tables), 4)
462462
dataset_secondary_table_names = {
463-
secondary_table.name for secondary_table in dataset.additional_data_tables
463+
secondary_table.name
464+
for _, secondary_table, _ in dataset.additional_data_tables
464465
}
465466
self.assertEqual(dataset_secondary_table_names, {"B", "C", "D", "E"})
466-
self.assertEqual(len(dataset.relations), 4)
467467

468468
table_specs = ds_spec["additional_data_tables"].items()
469-
for relation, (table_path, table_spec) in zip(dataset.relations, table_specs):
469+
for (ds_table_path, _, ds_is_one_to_one), (
470+
table_path,
471+
table_spec,
472+
) in zip(dataset.additional_data_tables, table_specs):
470473
# The dataset stores the full table path of each secondary table, as given in the spec
471-
self.assertEqual(relation[1], table_path.split("/")[-1])
474+
self.assertEqual(ds_table_path, table_path)
472475
if len(table_spec) == 3:
473-
self.assertEqual(relation[2], table_spec[2])
476+
self.assertEqual(ds_is_one_to_one, table_spec[2])
474477
else:
475-
self.assertFalse(relation[2])
478+
self.assertFalse(ds_is_one_to_one)
476479

477480
def test_out_file_from_dataframe_monotable(self):
478481
"""Test consistency of the created data file with the input dataframe
@@ -746,7 +749,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
746749

747750
# Check that the domain has the same table names as the reference
748751
ref_table_names = {
749-
table.name for table in [ds.main_table] + ds.additional_data_tables
752+
table.name
753+
for table in [ds.main_table]
754+
+ [table for _, table, _ in ds.additional_data_tables]
750755
}
751756
out_table_names = {dictionary.name for dictionary in out_domain.dictionaries}
752757
self.assertEqual(ref_table_names, out_table_names)
@@ -759,7 +764,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
759764
# Check that:
760765
# - the table keys are the same as the dataset
761766
# - the domain has the same variable names as the reference
762-
for table in [ds.main_table] + ds.additional_data_tables:
767+
for table in [ds.main_table] + [
768+
table for _, table, _ in ds.additional_data_tables
769+
]:
763770
with self.subTest(table=table.name):
764771
self.assertEqual(table.key, out_domain.get_dictionary(table.name).key)
765772
out_dictionary_var_types = {

0 commit comments

Comments
 (0)