@@ -267,6 +267,107 @@ def test_dictionary(self):
267267 domain_copy .export_khiops_dictionary_file (copy_output_kdic )
268268 assert_files_equal (self , ref_kdic , copy_output_kdic )
269269
270+ def _build_mock_deprecated_data_path_api_method_parameters (self ):
271+ # Pseudo-mock data to test the creation of scenarios
272+ ref_additional_data_tables = {
273+ "Services" : "ServicesBidon.csv" ,
274+ "Services/Usages" : "UsagesBidon.csv" ,
275+ "Address" : "AddressBidon.csv" ,
276+ }
277+ ref_output_additional_data_tables = {
278+ "Services" : "TransferServicesBidon.csv" ,
279+ "Services/Usages" : "TransferUsagesBidon.csv" ,
280+ "Address" : "TransferAddressBidon.csv" ,
281+ }
282+ additional_data_tables = {
283+ "Customer`Services" : "ServicesBidon.csv" ,
284+ "Customer`Services`Usages" : "UsagesBidon.csv" ,
285+ "Customer`Address" : "AddressBidon.csv" ,
286+ }
287+ output_additional_data_tables = {
288+ "Customer`Services" : "TransferServicesBidon.csv" ,
289+ "Customer`Services`Usages" : "TransferUsagesBidon.csv" ,
290+ "Customer`Address" : "TransferAddressBidon.csv" ,
291+ }
292+
293+ # Store the relation method_name -> (dataset -> mock args and kwargs)
294+ method_test_args = {
295+ "check_database" : {
296+ "args" : ["Customer.kdic" , "Customer" , "Customer.csv" ],
297+ "kwargs" : {"additional_data_tables" : copy (additional_data_tables )},
298+ },
299+ # We profit to test byte strings in the deploy_model test
300+ "deploy_model" : {
301+ "args" : [
302+ bytes ("Customer.kdic" , encoding = "ascii" ),
303+ bytes ("Customer" , encoding = "ascii" ),
304+ bytes ("Customer.csv" , encoding = "ascii" ),
305+ bytes ("CustomerDeployed.csv" , encoding = "ascii" ),
306+ ],
307+ "kwargs" : {
308+ "additional_data_tables" : (
309+ {
310+ bytes (key , encoding = "ascii" ): bytes (value , encoding = "ascii" )
311+ for key , value in additional_data_tables .items ()
312+ }
313+ ),
314+ "output_additional_data_tables" : (
315+ {
316+ bytes (key , encoding = "ascii" ): bytes (value , encoding = "ascii" )
317+ for key , value in output_additional_data_tables .items ()
318+ }
319+ ),
320+ },
321+ },
322+ "evaluate_predictor" : {
323+ "args" : [
324+ "ModelingCustomer.kdic" ,
325+ "Customer" ,
326+ "Customer.csv" ,
327+ "CustomerResults/CustomerAnalysisResults.khj" ,
328+ ],
329+ "kwargs" : {"additional_data_tables" : copy (additional_data_tables )},
330+ },
331+ "train_coclustering" : {
332+ "args" : [
333+ "Customer.kdic" ,
334+ "Customer" ,
335+ "Customer.csv" ,
336+ ["id_customer" , "Name" ],
337+ "CustomerResults/CustomerCoclusteringResults._khcj" ,
338+ ],
339+ "kwargs" : {
340+ "additional_data_tables" : copy (additional_data_tables ),
341+ },
342+ },
343+ "train_predictor" : {
344+ "args" : [
345+ "Customer.kdic" ,
346+ "Customer" ,
347+ "Customer.csv" ,
348+ "" ,
349+ "CustomerResults/CustomerAnalysisResults._khj" ,
350+ ],
351+ "kwargs" : {
352+ "additional_data_tables" : copy (additional_data_tables ),
353+ },
354+ },
355+ "train_recoder" : {
356+ "args" : [
357+ "Customer.kdic" ,
358+ "Customer" ,
359+ "Customer.csv" ,
360+ "" ,
361+ "CustomerResults/CustomerAnalysisResults._khj" ,
362+ ],
363+ "kwargs" : {
364+ "additional_data_tables" : copy (additional_data_tables ),
365+ },
366+ },
367+ }
368+
369+ return method_test_args
370+
270371 def _build_mock_api_method_parameters (self ):
271372 # Pseudo-mock data to test the creation of scenarios
272373 datasets = ["Adult" , "SpliceJunction" , "Customer" ]
@@ -607,118 +708,77 @@ def test_data_path_deprecation_in_api_method(self):
607708 kh .set_runner (test_runner )
608709
609710 # Obtain mock arguments for each API call
610- method_test_args = self ._build_mock_api_method_parameters ()
611-
612- # Define legacy additional data tables path
613- str_legacy_additional_data_tables = {
614- "Customer`Services" : "ServicesBidon.csv" ,
615- "Customer`Services`Usages" : "UsagesBidon.csv" ,
616- "Customer`Address" : "AddressBidon.csv" ,
617- }
711+ method_test_args = self ._build_mock_deprecated_data_path_api_method_parameters ()
618712
619713 # Test for each dataset mock parameters
620714 for method_name , method_full_args in method_test_args .items ():
621- # Use bytes for deploy_model's additional_data_tables
622- if method_name == "deploy_model" :
623- legacy_additional_data_tables = {
624- bytes (key , encoding = "ascii" ): bytes (value , encoding = "ascii" )
625- for key , value in str_legacy_additional_data_tables .items ()
626- }
627- else :
628- legacy_additional_data_tables = str_legacy_additional_data_tables
629715 # Set the runners test name
630716 test_runner .test_name = method_name
631717
632718 # Clean the directory for this method's tests
633719 cleanup_dir (test_runner .output_scenario_dir , "*/output/*._kh" , verbose = True )
634- for dataset , dataset_method_args in method_full_args .items ():
635- # Test only for the Customer dataset
636- if dataset != "Customer" :
637- continue
638-
639- test_runner .subtest_name = dataset
640- with self .subTest (method = method_name ):
641- # Get the API function and its args and kwargs
642- method = getattr (kh , method_name )
643- dataset_args = dataset_method_args ["args" ]
644- dataset_kwargs = dataset_method_args ["kwargs" ]
645-
646- # Skip the test if `additional_data_tables` is not an
647- # API call kwarg
648- if "additional_data_tables" not in dataset_kwargs :
649- continue
650-
651- # Store current additional data_tables
652- current_additional_data_tables = copy (
653- dataset_kwargs ["additional_data_tables" ]
654- )
655-
656- # Update the `additional_data_tables` kwargs to use
657- # legacy paths
658- dataset_kwargs ["additional_data_tables" ] = copy (
659- legacy_additional_data_tables
720+ test_runner .subtest_name = "Customer"
721+ with self .subTest (method = method_name ):
722+ # Get the API function and its args and kwargs
723+ method = getattr (kh , method_name )
724+ args = method_full_args ["args" ]
725+ kwargs = method_full_args ["kwargs" ]
726+
727+ # Test that using legacy paths entails a deprecation warning
728+ with warnings .catch_warnings (record = True ) as warning_list :
729+ method (* args , ** kwargs )
730+
731+ # Check the warning message
732+ if "output_additional_data_tables" in kwargs :
733+ self .assertEqual (
734+ len (warning_list ),
735+ len (kwargs ["additional_data_tables" ])
736+ + len (kwargs ["output_additional_data_tables" ]),
660737 )
661-
662- # Test that using legacy paths entails a deprecation warning
663- with warnings .catch_warnings (record = True ) as warning_list :
664- method (* dataset_args , ** dataset_kwargs )
665-
666- # Build current-legacy data path map
667- legacy_to_current_data_paths = {}
668- for (
669- data_path ,
670- data_file_path ,
671- ) in current_additional_data_tables .items ():
672- for (
673- leg_data_path ,
674- leg_data_file_path ,
675- ) in legacy_additional_data_tables .items ():
676- if leg_data_file_path == data_file_path :
677- legacy_to_current_data_paths [leg_data_path ] = data_path
678- break
679-
680- # Check the warning message
738+ else :
681739 self .assertEqual (
682- len (warning_list ), len (legacy_additional_data_tables )
740+ len (warning_list ), len (kwargs [ "additional_data_tables" ] )
683741 )
684- warning = warning_list [0 ]
685- for warning in warning_list :
686- self .assertTrue (issubclass (warning .category , UserWarning ))
687- warning_message = warning .message
688- self .assertEqual (len (warning_message .args ), 1 )
689- message = warning_message .args [0 ]
690- self .assertTrue (
691- "'`'-based dictionary data path" in message
692- and "deprecated" in message
693- )
694-
695- # Check legacy data path is replaced with the current
696- # data path
697- for legacy_data_path in legacy_additional_data_tables :
698- expected_legacy_data_path = legacy_to_current_data_paths [
699- legacy_data_path
700- ]
701- if f"'{ legacy_data_path } '" in message :
702- self .assertTrue (
703- f"'{ expected_legacy_data_path } '" in message
704- )
705- break
706742
743+ warning = warning_list [0 ]
744+ for warning in warning_list :
745+ self .assertTrue (issubclass (warning .category , UserWarning ))
746+ warning_message = warning .message
747+ self .assertEqual (len (warning_message .args ), 1 )
748+ message = warning_message .args [0 ]
749+ self .assertTrue (
750+ "'`'-based dictionary data path" in message
751+ and "deprecated" in message
752+ )
707753 # Restore the default runner
708754 kh .set_runner (default_runner )
709755
710756 def test_unknown_argument_in_api_method (self ):
711757 """Tests if core.api raises ValueError when an unknown argument is passed"""
758+ # Set the root directory of these tests
759+ test_resources_dir = os .path .join (resources_dir (), "scenario_generation" , "api" )
760+
761+ # Use the test runner that only compares the scenarios
762+ default_runner = kh .get_runner ()
763+ test_runner = ScenarioWriterRunner (self , test_resources_dir )
764+ kh .set_runner (test_runner )
765+
712766 # Obtain mock arguments for each API call
713767 method_test_args = self ._build_mock_api_method_parameters ()
714768
715769 # Test for each dataset mock parameters
716770 for method_name , method_full_args in method_test_args .items ():
771+ # Set the runners test name
772+ test_runner .test_name = method_name
773+
774+ # Clean the directory for this method's tests
775+ cleanup_dir (test_runner .output_scenario_dir , "*/output/*._kh" , verbose = True )
717776 for dataset , dataset_method_args in method_full_args .items ():
718777 # Test only for the Adult dataset
719778 if dataset != "Adult" :
720779 continue
721780
781+ test_runner .subtest_name = dataset
722782 with self .subTest (method = method_name ):
723783 # These methods do not have kwargs so they cannot have extra args
724784 if method_name in [
@@ -742,6 +802,9 @@ def test_unknown_argument_in_api_method(self):
742802 output_msg = str (context .exception )
743803 self .assertEqual (output_msg , expected_msg )
744804
805+ # Restore the default runner
806+ kh .set_runner (default_runner )
807+
745808 def test_system_settings (self ):
746809 """Test that the system settings are written to the scenario file"""
747810 # Create the root directory of these tests
0 commit comments