Skip to content

Commit 0527279

Browse files
Control the types on sklearn internal read table
We do this only for KhiopsClassifier and KhiopsRegressor: it is critical for KhiopsClassifier, as it accepts many target types, and it is trivial in the case of KhiopsRegressor. For KhiopsEncoder and KhiopsCoclustering it is less critical, and for the former it is very complex. We left them as TODOs. Additionally, we now also check in the "output type" tests that the result of predict is correct. Previously we only checked that the classes_ attribute was ok. This is to further ensure correctness.
1 parent 84da6f1 commit 0527279

File tree

3 files changed

+213
-122
lines changed

3 files changed

+213
-122
lines changed

khiops/sklearn/dataset.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def get_khiops_variable_name(column_id):
346346
return variable_name
347347

348348

349-
def read_internal_data_table(file_path_or_stream):
349+
def read_internal_data_table(file_path_or_stream, column_dtypes=None):
350350
"""Reads into a DataFrame a data table file with the internal format settings
351351
352352
The table is read with the following settings:
@@ -357,18 +357,34 @@ def read_internal_data_table(file_path_or_stream):
357357
- Use `csv.QUOTE_MINIMAL`
358358
- double quoting enabled (quotes within quotes can be escaped with '""')
359359
- UTF-8 encoding
360+
- User-specified dtypes (optional)
360361
361362
Parameters
362363
----------
363364
file_path_or_stream : str or file object
364365
The path of the internal data table file to be read or a readable file
365366
object.
367+
column_dtypes : dict, optional
368+
Dictionary linking column names with dtypes. See ``dtype`` parameter of the
369+
`pandas.read_csv` function. If not set, then the column types are detected
370+
automatically by pandas.
366371
367372
Returns
368373
-------
369374
`pandas.DataFrame`
370-
The dataframe representation.
375+
The dataframe representation of the data table.
371376
"""
377+
# Change the 'U' types (Unicode strings) to 'O' because pandas does not support them
378+
# in read_csv
379+
if column_dtypes is not None:
380+
execution_column_dtypes = {}
381+
for column_name, dtype in column_dtypes.items():
382+
if hasattr(dtype, "kind") and dtype.kind == "U":
383+
execution_column_dtypes[column_name] = np.dtype("O")
384+
else:
385+
execution_column_dtypes = None
386+
387+
# Read and return the dataframe
372388
return pd.read_csv(
373389
file_path_or_stream,
374390
sep="\t",
@@ -377,6 +393,7 @@ def read_internal_data_table(file_path_or_stream):
377393
quoting=csv.QUOTE_MINIMAL,
378394
doublequote=True,
379395
encoding="utf-8",
396+
dtype=execution_column_dtypes,
380397
)
381398

382399

@@ -1132,6 +1149,11 @@ def __repr__(self):
11321149
f"dtypes={dtypes_str}>"
11331150
)
11341151

1152+
def get_column_dtype(self, column_id):
1153+
if column_id not in self.data_source.dtypes:
1154+
raise KeyError(f"Column '{column_id}' not found in the dtypes field")
1155+
return self.data_source.dtypes[column_id]
1156+
11351157
def create_table_file_for_khiops(
11361158
self, output_dir, sort=True, target_column=None, target_column_id=None
11371159
):
@@ -1214,6 +1236,9 @@ def __repr__(self):
12141236
f"dtype={dtype_str}; target={self.target_column_id}>"
12151237
)
12161238

1239+
def get_column_dtype(self, _):
1240+
return self.data_source.dtype
1241+
12171242
def create_table_file_for_khiops(
12181243
self, output_dir, sort=True, target_column=None, target_column_id=None
12191244
):
@@ -1300,6 +1325,9 @@ def __repr__(self):
13001325
f"dtype={dtype_str}>"
13011326
)
13021327

1328+
def get_column_dtype(self, _):
1329+
return self.data_source.dtype
1330+
13031331
def create_khiops_dictionary(self):
13041332
"""Creates a Khiops dictionary representing this sparse table
13051333

0 commit comments

Comments
 (0)