Skip to content

Commit fa6e8db

Browse files
committed
Handle secondary-table data path deprecation in the Core API
More specifically, preprocess the data paths in the `additional_data_tables` and `output_additional_data_tables` arguments: - split each path on "`" - check that each data path fragment after the split is non-empty - check that the first fragment is identical to the the name of the current dictionary (as specified in `dictionary_name` or `train_dictionary_name`) If all these conditions are met, then convert the legacy data path to the new format: - drop the current dictionary name fragment from the beginning of the path - join the remaining data path fragments on "/" Note: This does not handle external tables, but for the impending beta, this seems like an acceptable tradeoff. closes #370
1 parent 977bcf9 commit fa6e8db

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed

khiops/core/api.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,72 @@ def _preprocess_arguments(args):
198198
return command_line_options, system_settings, task_is_called_with_domain
199199

200200

201+
def _deprecate_legacy_data_path(data_path_task_arg_name, task_args):
202+
"""Detect and replace legacy data path with the current syntax
203+
204+
.. note:: The function mutates task_args.
205+
"""
206+
if (
207+
data_path_task_arg_name in task_args
208+
and task_args[data_path_task_arg_name] is not None
209+
):
210+
assert "dictionary_name" in task_args or "train_dictionary_name" in task_args
211+
if "dictionary_name" in task_args:
212+
current_dictionary_name = task_args["dictionary_name"]
213+
else:
214+
current_dictionary_name = task_args["train_dictionary_name"]
215+
216+
for kdic_path in task_args[data_path_task_arg_name].keys():
217+
if isinstance(kdic_path, str):
218+
deprecated_data_path_separator = "`"
219+
data_path_separator = "/"
220+
kdic_path_for_warning = kdic_path
221+
else:
222+
assert isinstance(kdic_path, bytes)
223+
deprecated_data_path_separator = b"`"
224+
data_path_separator = b"/"
225+
if isinstance(current_dictionary_name, str):
226+
current_dictionary_name = bytes(
227+
current_dictionary_name, encoding="ascii"
228+
)
229+
kdic_path_for_warning = kdic_path.decode("ascii")
230+
231+
# Path split "`" yields non-empty fragments; the first fragment
232+
# starts with the current dictionary name
233+
kdic_path_parts = kdic_path.split(deprecated_data_path_separator)
234+
if all(len(path_part) > 0 for path_part in kdic_path_parts):
235+
source_dictionary_name = kdic_path_parts[0]
236+
if source_dictionary_name == current_dictionary_name:
237+
new_kdic_path_parts = []
238+
239+
# Escape any "/" char in the path parts except for the
240+
# current dictionary, which is is skipped from the new path
241+
for kdic_path_part in kdic_path_parts[1:]:
242+
new_kdic_path_parts.append(
243+
kdic_path_part.replace(
244+
data_path_separator,
245+
deprecated_data_path_separator + data_path_separator,
246+
)
247+
)
248+
249+
# Replace the legacy data path with the current data path
250+
new_kdic_path = data_path_separator.join(new_kdic_path_parts)
251+
kdic_file_path = task_args[data_path_task_arg_name].pop(kdic_path)
252+
task_args[data_path_task_arg_name][new_kdic_path] = kdic_file_path
253+
warnings.warn(
254+
deprecation_message(
255+
"'`'-based dictionary data path: "
256+
f"'{kdic_path_for_warning}'",
257+
"11.0.1",
258+
replacement=(
259+
"'/'-based dictionary data path "
260+
f"convention: '{new_kdic_path}'"
261+
),
262+
quote=False,
263+
)
264+
)
265+
266+
201267
def _preprocess_task_arguments(task_args):
202268
"""Preprocessing of task arguments common to various tasks
203269
@@ -320,6 +386,14 @@ def _preprocess_task_arguments(task_args):
320386
)
321387
del task_args["max_variable_importances"]
322388

389+
# Detect and replace deprecated data-path syntax on additional_data_tables
390+
# Mutate task_args in the process
391+
for data_path_task_arg_name in (
392+
"additional_data_tables",
393+
"output_additional_data_tables",
394+
):
395+
_deprecate_legacy_data_path(data_path_task_arg_name, task_args)
396+
323397
# Flatten kwargs
324398
if "kwargs" in task_args:
325399
task_args.update(task_args["kwargs"])

tests/test_core.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import textwrap
1313
import unittest
1414
import warnings
15+
from copy import copy
1516
from pathlib import Path
1617
from unittest import mock
1718

@@ -593,6 +594,119 @@ def test_api_scenario_generation(self):
593594
# Restore the default runner
594595
kh.set_runner(default_runner)
595596

597+
def test_data_path_deprecation_in_api_method(self):
598+
"""Tests if core.api issues deprecation warning when legacy data
599+
paths are used for secondary tables
600+
"""
601+
# Set the root directory of these tests
602+
test_resources_dir = os.path.join(resources_dir(), "scenario_generation", "api")
603+
604+
# Use the test runner that only compares the scenarios
605+
default_runner = kh.get_runner()
606+
test_runner = ScenarioWriterRunner(self, test_resources_dir)
607+
kh.set_runner(test_runner)
608+
609+
# Obtain mock arguments for each API call
610+
method_test_args = self._build_mock_api_method_parameters()
611+
612+
# Define legacy additional data tables path
613+
str_legacy_additional_data_tables = {
614+
"Customer`Services": "ServicesBidon.csv",
615+
"Customer`Services`Usages": "UsagesBidon.csv",
616+
"Customer`Address": "AddressBidon.csv",
617+
}
618+
619+
# Test for each dataset mock parameters
620+
for method_name, method_full_args in method_test_args.items():
621+
# Use bytes for deploy_model's additional_data_tables
622+
if method_name == "deploy_model":
623+
legacy_additional_data_tables = {
624+
bytes(key, encoding="ascii"): bytes(value, encoding="ascii")
625+
for key, value in str_legacy_additional_data_tables.items()
626+
}
627+
else:
628+
legacy_additional_data_tables = str_legacy_additional_data_tables
629+
# Set the runners test name
630+
test_runner.test_name = method_name
631+
632+
# Clean the directory for this method's tests
633+
cleanup_dir(test_runner.output_scenario_dir, "*/output/*._kh", verbose=True)
634+
for dataset, dataset_method_args in method_full_args.items():
635+
# Test only for the Customer dataset
636+
if dataset != "Customer":
637+
continue
638+
639+
test_runner.subtest_name = dataset
640+
with self.subTest(method=method_name):
641+
# Get the API function and its args and kwargs
642+
method = getattr(kh, method_name)
643+
dataset_args = dataset_method_args["args"]
644+
dataset_kwargs = dataset_method_args["kwargs"]
645+
646+
# Skip the test if `additional_data_tables` is not an
647+
# API call kwarg
648+
if "additional_data_tables" not in dataset_kwargs:
649+
continue
650+
651+
# Store current additional data_tables
652+
current_additional_data_tables = copy(
653+
dataset_kwargs["additional_data_tables"]
654+
)
655+
656+
# Update the `additional_data_tables` kwargs to use
657+
# legacy paths
658+
dataset_kwargs["additional_data_tables"] = copy(
659+
legacy_additional_data_tables
660+
)
661+
662+
# Test that using legacy paths entails a deprecation warning
663+
with warnings.catch_warnings(record=True) as warning_list:
664+
method(*dataset_args, **dataset_kwargs)
665+
666+
# Build current-legacy data path map
667+
legacy_to_current_data_paths = {}
668+
for (
669+
data_path,
670+
data_file_path,
671+
) in current_additional_data_tables.items():
672+
for (
673+
leg_data_path,
674+
leg_data_file_path,
675+
) in legacy_additional_data_tables.items():
676+
if leg_data_file_path == data_file_path:
677+
legacy_to_current_data_paths[leg_data_path] = data_path
678+
break
679+
680+
# Check the warning message
681+
self.assertEqual(
682+
len(warning_list), len(legacy_additional_data_tables)
683+
)
684+
warning = warning_list[0]
685+
for warning in warning_list:
686+
self.assertTrue(issubclass(warning.category, UserWarning))
687+
warning_message = warning.message
688+
self.assertEqual(len(warning_message.args), 1)
689+
message = warning_message.args[0]
690+
self.assertTrue(
691+
"'`'-based dictionary data path" in message
692+
and "deprecated" in message
693+
)
694+
695+
# Check legacy data path is replaced with the current
696+
# data path
697+
for legacy_data_path in legacy_additional_data_tables:
698+
expected_legacy_data_path = legacy_to_current_data_paths[
699+
legacy_data_path
700+
]
701+
if f"'{legacy_data_path}'" in message:
702+
self.assertTrue(
703+
f"'{expected_legacy_data_path}'" in message
704+
)
705+
break
706+
707+
# Restore the default runner
708+
kh.set_runner(default_runner)
709+
596710
def test_unknown_argument_in_api_method(self):
597711
"""Tests if core.api raises ValueError when an unknown argument is passed"""
598712
# Obtain mock arguments for each API call

0 commit comments

Comments
 (0)