Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/samples/samples_sklearn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ Samples
keep_initial_variables=True,
transform_type_categorical="part_id",
transform_type_numerical="part_id",
transform_pairs="part_id",
transform_type_pairs="part_id",
)
khe.fit(X, y)

Expand Down
2 changes: 1 addition & 1 deletion khiops/samples/samples_sklearn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@
" keep_initial_variables=True,\n",
" transform_type_categorical=\"part_id\",\n",
" transform_type_numerical=\"part_id\",\n",
" transform_pairs=\"part_id\",\n",
" transform_type_pairs=\"part_id\",\n",
")\n",
"khe.fit(X, y)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion khiops/samples/samples_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def khiops_encoder_with_hyperparameters():
keep_initial_variables=True,
transform_type_categorical="part_id",
transform_type_numerical="part_id",
transform_pairs="part_id",
transform_type_pairs="part_id",
)
khe.fit(X, y)

Expand Down
32 changes: 30 additions & 2 deletions khiops/sklearn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def get_khiops_variable_name(column_id):
return variable_name


def read_internal_data_table(file_path_or_stream):
def read_internal_data_table(file_path_or_stream, column_dtypes=None):
"""Reads into a DataFrame a data table file with the internal format settings

The table is read with the following settings:
Expand All @@ -357,18 +357,34 @@ def read_internal_data_table(file_path_or_stream):
- Use `csv.QUOTE_MINIMAL`
- double quoting enabled (quotes within quotes can be escaped with '""')
- UTF-8 encoding
- User-specified dtypes (optional)

Parameters
----------
file_path_or_stream : str or file object
The path of the internal data table file to be read or a readable file
object.
column_dtypes : dict, optional
Dictionary linking column names with dtypes. See ``dtype`` parameter of the
`pandas.read_csv` function. If not set, then the column types are detected
automatically by pandas.

Returns
-------
`pandas.DataFrame`
The dataframe representation.
The dataframe representation of the data table.
"""
# Change the 'U' types (Unicode strings) to 'O' because pandas does not support them
# in read_csv
if column_dtypes is not None:
execution_column_dtypes = {}
for column_name, dtype in column_dtypes.items():
if hasattr(dtype, "kind") and dtype.kind == "U":
execution_column_dtypes[column_name] = np.dtype("O")
else:
execution_column_dtypes = None

# Read and return the dataframe
return pd.read_csv(
file_path_or_stream,
sep="\t",
Expand All @@ -377,6 +393,7 @@ def read_internal_data_table(file_path_or_stream):
quoting=csv.QUOTE_MINIMAL,
doublequote=True,
encoding="utf-8",
dtype=execution_column_dtypes,
)


Expand Down Expand Up @@ -1132,6 +1149,11 @@ def __repr__(self):
f"dtypes={dtypes_str}>"
)

def get_column_dtype(self, column_id):
    """Return the dtype of the column identified by ``column_id``.

    Parameters
    ----------
    column_id : str or int
        Key of the column in the data source's ``dtypes`` mapping.

    Returns
    -------
    dtype
        The dtype stored for the column.

    Raises
    ------
    KeyError
        If ``column_id`` is not present in the data source's dtypes.
    """
    # EAFP: attempt the lookup and translate a miss into the documented error.
    try:
        return self.data_source.dtypes[column_id]
    except KeyError:
        raise KeyError(f"Column '{column_id}' not found in the dtypes field") from None

def create_table_file_for_khiops(
self, output_dir, sort=True, target_column=None, target_column_id=None
):
Expand Down Expand Up @@ -1214,6 +1236,9 @@ def __repr__(self):
f"dtype={dtype_str}; target={self.target_column_id}>"
)

def get_column_dtype(self, _):
    """Return the single dtype of the underlying data source.

    The column id argument is ignored: this data source exposes one
    ``dtype`` attribute for all of its columns (see the ``__repr__``
    above, which reports a single ``dtype``).
    """
    return self.data_source.dtype

def create_table_file_for_khiops(
self, output_dir, sort=True, target_column=None, target_column_id=None
):
Expand Down Expand Up @@ -1300,6 +1325,9 @@ def __repr__(self):
f"dtype={dtype_str}>"
)

def get_column_dtype(self, _):
    """Return the single dtype of the underlying sparse data source.

    The column id argument is ignored: this data source exposes one
    ``dtype`` attribute shared by every column.
    """
    return self.data_source.dtype

def create_khiops_dictionary(self):
"""Creates a Khiops dictionary representing this sparse table

Expand Down
Loading