3737 DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML ,
3838 DATA_CONFIG_MULTITURN_DATA_YAML ,
3939 DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML ,
40- DATA_CONFIG_RENAME_RETAIN_COLUMNS ,
40+ DATA_CONFIG_RENAME_SELECT_COLUMNS ,
4141 DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML ,
4242 DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT ,
4343 DATA_CONFIG_YAML_STREAMING_PRETOKENIZED ,
7070from tuning .config import configs
7171from tuning .config .acceleration_configs import AttentionAndDistributedPackingConfig
7272from tuning .data .collators import VisionDataCollator
73- from tuning .data .data_config import DataPreProcessorConfig , DataSetConfig
73+ from tuning .data .data_config import (
74+ DataHandlerConfig ,
75+ DataPreProcessorConfig ,
76+ DataSetConfig ,
77+ )
7478from tuning .data .data_preprocessing_utils import get_data_collator
7579from tuning .data .data_processors import DataPreProcessor , get_datapreprocessor
7680from tuning .data .setup_dataprocessor import (
@@ -1674,33 +1678,33 @@ def test_process_dataset_configs_with_sampling_error(
16741678
16751679
16761680@pytest .mark .parametrize (
1677- "datafile, rename, retain , final, datasetconfigname" ,
1681+ "datafile, rename, select , final, datasetconfigname" ,
16781682 [
16791683 (
16801684 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON ,
16811685 {"input" : "instruction" , "output" : "response" },
16821686 None ,
16831687 ["ID" , "Label" , "instruction" , "response" ],
1684- DATA_CONFIG_RENAME_RETAIN_COLUMNS ,
1688+ DATA_CONFIG_RENAME_SELECT_COLUMNS ,
16851689 ),
16861690 (
16871691 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON ,
16881692 None ,
16891693 ["ID" , "input" , "output" ],
16901694 ["ID" , "input" , "output" ],
1691- DATA_CONFIG_RENAME_RETAIN_COLUMNS ,
1695+ DATA_CONFIG_RENAME_SELECT_COLUMNS ,
16921696 ),
16931697 (
16941698 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON ,
16951699 {"input" : "instruction" , "output" : "response" },
16961700 ["Label" , "instruction" , "response" ],
16971701 ["Label" , "instruction" , "response" ],
1698- DATA_CONFIG_RENAME_RETAIN_COLUMNS ,
1702+ DATA_CONFIG_RENAME_SELECT_COLUMNS ,
16991703 ),
17001704 ],
17011705)
1702- def test_rename_and_retain_dataset_columns (
1703- datafile , rename , retain , final , datasetconfigname
1706+ def test_rename_and_select_dataset_columns (
1707+ datafile , rename , select , final , datasetconfigname
17041708):
17051709 """Test process_dataset_configs for expected output."""
17061710 dataprocessor_config = DataPreProcessorConfig ()
@@ -1709,12 +1713,23 @@ def test_rename_and_retain_dataset_columns(
17091713 processor_config = dataprocessor_config ,
17101714 tokenizer = tokenizer ,
17111715 )
1716+
1717+ handlers = []
1718+ if rename :
1719+ handlers .append (
1720+ DataHandlerConfig (
1721+ name = "rename_columns" , arguments = {"column_mapping" : rename }
1722+ )
1723+ )
1724+ if select :
1725+ handlers .append (
1726+ DataHandlerConfig (name = "select_columns" , arguments = {"column_names" : select })
1727+ )
1728+ data_paths = [datafile ]
1729+
17121730 datasetconfig = [
17131731 DataSetConfig (
1714- name = datasetconfigname ,
1715- data_paths = [datafile ],
1716- rename_columns = rename ,
1717- retain_columns = retain ,
1732+ name = datasetconfigname , data_paths = data_paths , data_handlers = handlers
17181733 )
17191734 ]
17201735 train_dataset = processor .process_dataset_configs (dataset_configs = datasetconfig )
0 commit comments