Skip to content

Commit 632bdf1

Browse files
Control the types on sklearn internal read table
We do this only for KhiopsClassifier and KhiopsRegressor: it is critical for KhiopsClassifier, as it accepts many target types, and it is trivial in the case of KhiopsRegressor. For KhiopsEncoder and KhiopsCoclustering it is less critical, and for the former it is very complex; we left them as TODOs. Additionally, the "output type" tests now also check that the result of predict is correct. Before, we only checked that the classes_ attribute was OK. This is to further ensure correctness.
1 parent 131974c commit 632bdf1

3 files changed

Lines changed: 214 additions & 122 deletions

File tree

khiops/sklearn/dataset.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def get_khiops_variable_name(column_id):
346346
return variable_name
347347

348348

349-
def read_internal_data_table(file_path_or_stream):
349+
def read_internal_data_table(file_path_or_stream, column_dtypes=None):
350350
"""Reads into a DataFrame a data table file with the internal format settings
351351
352352
The table is read with the following settings:
@@ -357,18 +357,34 @@ def read_internal_data_table(file_path_or_stream):
357357
- Use `csv.QUOTE_MINIMAL`
358358
- double quoting enabled (quotes within quotes can be escaped with '""')
359359
- UTF-8 encoding
360+
- User-specified dtypes (optional)
360361
361362
Parameters
362363
----------
363364
file_path_or_stream : str or file object
364365
The path of the internal data table file to be read or a readable file
365366
object.
367+
column_dtypes : dict, optional
368+
Dictionary linking column names with dtypes. See ``dtype`` parameter of the
369+
`pandas.read_csv` function. If not set, then the column types are detected
370+
automatically by pandas.
366371
367372
Returns
368373
-------
369374
`pandas.DataFrame`
370-
The dataframe representation.
375+
The dataframe representation of the data table.
371376
"""
377+
# Change the 'U' types (Unicode strings) to 'O' because pandas does not support them
378+
# in read_csv
379+
if column_dtypes is not None:
380+
execution_column_dtypes = {}
381+
for column_name, dtype in column_dtypes.items():
382+
if hasattr(dtype, "kind") and dtype.kind == "U":
383+
execution_column_dtypes[column_name] = np.dtype("O")
384+
else:
385+
execution_column_dtypes = None
386+
387+
# Read and return the dataframe
372388
return pd.read_csv(
373389
file_path_or_stream,
374390
sep="\t",
@@ -377,6 +393,7 @@ def read_internal_data_table(file_path_or_stream):
377393
quoting=csv.QUOTE_MINIMAL,
378394
doublequote=True,
379395
encoding="utf-8",
396+
dtype=execution_column_dtypes,
380397
)
381398

382399

@@ -1132,6 +1149,12 @@ def __repr__(self):
11321149
f"dtypes={dtypes_str}>"
11331150
)
11341151

1152+
def get_column_dtype(self, column_id):
# Returns the entry of `self.data_source.dtypes` stored under `column_id`.
# The assertion below requires `column_id` to be present in that mapping;
# note that assertions are skipped when Python runs with `-O`.
1153+
assert (
1154+
column_id in self.data_source.dtypes
1155+
), f"Column '{column_id}' not found in dtypes field"
1156+
return self.data_source.dtypes[column_id]
1157+
11351158
def create_table_file_for_khiops(
11361159
self, output_dir, sort=True, target_column=None, target_column_id=None
11371160
):
@@ -1214,6 +1237,9 @@ def __repr__(self):
12141237
f"dtype={dtype_str}; target={self.target_column_id}>"
12151238
)
12161239

1240+
def get_column_dtype(self, _):
# Returns `self.data_source.dtype`; the positional argument is ignored, so the
# same dtype is reported whatever column identifier is passed.
1241+
return self.data_source.dtype
1242+
12171243
def create_table_file_for_khiops(
12181244
self, output_dir, sort=True, target_column=None, target_column_id=None
12191245
):
@@ -1300,6 +1326,9 @@ def __repr__(self):
13001326
f"dtype={dtype_str}>"
13011327
)
13021328

1329+
def get_column_dtype(self, _):
# Returns `self.data_source.dtype`; the positional argument is ignored, so the
# same dtype is reported whatever column identifier is passed.
1330+
return self.data_source.dtype
1331+
13031332
def create_khiops_dictionary(self):
13041333
"""Creates a Khiops dictionary representing this sparse table
13051334

0 commit comments

Comments
 (0)